### CUDA setup

In [None]:
import os
import random
import warnings

import numpy as np
import torch

# Seed every RNG the pipeline may touch (Python's `random`, NumPy's legacy
# global RNG, and torch) so that runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# With CUDA >= 10.2, deterministic mode makes cuBLAS-backed ops raise a
# RuntimeError unless CUBLAS_WORKSPACE_CONFIG is set. It has to be in the
# environment before the first cuBLAS call, so it is set here, before anything
# runs on the GPU. `setdefault` keeps a value already exported by the shell
# (e.g. ":16:8").
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)

# True when a CUDA device is visible to torch.
cuda = torch.cuda.is_available()

# NOTE(review): this silences *all* warnings for the rest of the kernel
# session, including deprecation/numerical warnings from torch and Lightning.
warnings.filterwarnings('ignore')

### Define dataset

In [None]:
from dataset import InstaDataset
from torch.utils.data import DataLoader
import torchvision.transforms as T

# Image preprocessing: convert to tensor, resize (smaller edge to 400),
# center-crop to 400x400, then normalize per channel with the usual
# ImageNet mean/std values.
transform = T.Compose([
    T.ToTensor(),
    T.Resize(400),
    T.CenterCrop(400),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Experiment flags.
transform_target = False
pretrained = False

DATA_ROOT = '../data_new/'


def _make_split(split_json):
    """Build an InstaDataset for one split file, sharing root and transforms."""
    return InstaDataset(split_json, DATA_ROOT,
                        transform=transform,
                        transform_target=transform_target)


train_ds = _make_split('../train.json')
val_ds = _make_split('../val.json')

# Options shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, **loader_kwargs)

### Define logger

In [None]:
import wandb
from pytorch_lightning.loggers import WandbLogger

# NOTE(review): relogin=True forces a fresh W&B login on every run (this may
# prompt for an API key), which blocks an unattended Restart & Run All; drop it
# if cached credentials are sufficient.
wandb.login(relogin=True)

# Derive the run tag from the `transform_target` flag (defined in the dataset
# cell) instead of hard-coding it, so the tag cannot drift out of sync with the
# configuration actually used.
target_tag = 'target_transform' if transform_target else 'no_target_transform'

wandb_logger = WandbLogger(project='PopIn',
                           entity="ids_course",
                           log_model=True,  # log model checkpoints to W&B
                           tags=[target_tag],
                           settings=wandb.Settings(start_method="thread"))

### Define training pipeline

In [None]:
from lightning import Insta
from pytorch_lightning import Trainer

# NOTE: `wandb_logger` comes from the logger cell and its run is closed at the
# end of this cell; re-run the logger cell before re-running this one so that
# training logs to a fresh W&B run.
model = Insta(dim=32, lr=1e-5, weight_decay=1e-5,
              transform_target=transform_target, pretrained=pretrained)

# Log gradients and parameters every 10 batches, plus the model graph.
wandb_logger.watch(model, log="all", log_graph=True, log_freq=10)

# Use the GPU only when the setup cell detected one; otherwise fall back to CPU
# instead of failing at Trainer construction.
trainer = Trainer(devices=1,
                  accelerator="gpu" if cuda else "cpu",
                  logger=wandb_logger,
                  max_epochs=10,
                  log_every_n_steps=10)

# Close the W&B run even if fit() raises or is interrupted; otherwise the run
# is left active in this kernel.
try:
    trainer.fit(model, train_loader, val_loader)
finally:
    wandb.finish()