<a href="https://colab.research.google.com/github/vanshtibrewal/CLIP-classifier/blob/main/Vision_Lab_Task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Install Requirements

In [1]:
!pip install pytorch_lightning wandb

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.1.2-py3-none-any.whl (776 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m776.9/776.9 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting wandb
  Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m27.0 MB/s[0m eta [36m0:00:

#Import Requirements

In [2]:
import torchvision
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, TensorDataset
import os
from transformers import CLIPProcessor, CLIPModel
import pytorch_lightning as pl
import wandb
import getpass
from pytorch_lightning.loggers import WandbLogger

#Mount Google Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#Hyperparameters/Arbitrary Constants

In [5]:
batch_size = 32
train_test_split = 0.8
pl.seed_everything(42, workers=True)
torch.backends.cudnn.deterministic = True
num_workers = 2
lr = 1e-3
decay = 1e-5
step_size = 10
gamma = 0.1

INFO:lightning_fabric.utilities.seed:Seed set to 42


#Load/Download/Process/Save Data

##Initial Download, Preprocessing, and Saving - first time only

In [None]:
dataset_dir = '/content/drive/My Drive/VisionLab/Caltech101'
caltech_dataset = torchvision.datasets.Caltech101(root=dataset_dir, download=False)

In [None]:
save_dir = '/content/drive/My Drive/VisionLab/PreprocessedCaltech101'
os.makedirs(save_dir, exist_ok=True)

In [None]:
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

processed_images = []
labels = []

In [None]:
total_len = len(caltech_dataset)
for idx, (image, label) in enumerate(caltech_dataset):
    inputs = processor(images=image, return_tensors="pt")["pixel_values"].squeeze()
    processed_images.append(inputs)
    labels.append(label)
    if idx % 1000 == 0:
      print("Progress Update:", (100.0*idx/total_len))

Progress Update: 0.0
Progress Update: 11.524720525527256
Progress Update: 23.04944105105451
Progress Update: 34.574161576581766
Progress Update: 46.09888210210902
Progress Update: 57.62360262763628
Progress Update: 69.14832315316353
Progress Update: 80.67304367869079
Progress Update: 92.19776420421805


In [None]:
processed_images = torch.stack(processed_images)
labels = torch.tensor(labels)

In [None]:
torch.save(processed_images, os.path.join(save_dir, 'processed_images.pt'))
torch.save(labels, os.path.join(save_dir, 'labels.pt'))

##Load Preprocessed Data

In [6]:
save_dir = '/content/drive/My Drive/VisionLab/PreprocessedCaltech101'
images = torch.load(os.path.join(save_dir, 'processed_images.pt'))
labels = torch.load(os.path.join(save_dir, 'labels.pt'))

In [7]:
dataset = TensorDataset(images, labels)
train_size = int(train_test_split*len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#Create Model

In [9]:
class CLIPClassifier(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    self.encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 101))

    for param in self.encoder.parameters():
      param.requires_grad = False

  def forward(self, x):
    processed = self.processor(x, return_tensors="pt")["pixel_values"].squeeze()
    embedding = self.encoder.get_image_features(processed)
    out = self.classifier(embedding)
    out = F.softmax(out, dim=1)
    return torch.argmax(out, dim=1)

  def training_step(self, batch, batch_idx):
    processed_imgs, labels = batch
    z = self.encoder.get_image_features(processed_imgs)
    out = self.classifier(z)
    out = F.log_softmax(out, dim=1)
    loss = nn.NLLLoss()(out, labels)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def validation_step(self, batch, batch_idx):
    processed_imgs, labels = batch
    z = self.encoder.get_image_features(processed_imgs)
    out = self.classifier(z)
    out = F.log_softmax(out, dim=1)
    loss = nn.NLLLoss()(out, labels)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr, weight_decay=decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    return [optimizer], [scheduler]

# Wandb Setup


In [10]:
SECRET_KEY = getpass.getpass("API key:")
wandb.login(key=SECRET_KEY, relogin=True)

API key:··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

#Training

In [11]:
model = CLIPClassifier()
wandb_logger = WandbLogger(name=f"Base - lr {lr} weight decay {decay} step size {step_size} gamma {gamma}", project="CLIPClassifier")
trainer = pl.Trainer(max_epochs=500, logger=wandb_logger)
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/568 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mvanshtibrewal[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type       | Params
------------------------------------------
0 | encoder    | CLIPModel  | 151 M 
1 | classifier | Sequential | 177 K 
------------------------------------------
177 K     Trainable params
151 M     Non-trainable params
151 M     Total params
605.818   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


KeyboardInterrupt: ignored

In [12]:
torch.save(model.state_dict(), '/content/drive/My Drive/VisionLab/base_model.pth')

In [16]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.014 MB uploaded\r'), FloatProgress(value=0.08598703923071563, max=1.…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███▁
train_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███▁
val_loss,█▄▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂

0,1
epoch,0.0
train_loss,0.03444
trainer/global_step,216.0
val_loss,0.21904


# Nano Model - Reduce Overfitting

In [17]:
class NanoCLIPClassifier(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    self.encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 101))

    for param in self.encoder.parameters():
      param.requires_grad = False

  def forward(self, x):
    processed = self.processor(x, return_tensors="pt")["pixel_values"].squeeze()
    embedding = self.encoder.get_image_features(processed)
    out = self.classifier(embedding)
    out = F.softmax(out, dim=1)
    return torch.argmax(out, dim=1)

  def training_step(self, batch, batch_idx):
    processed_imgs, labels = batch
    z = self.encoder.get_image_features(processed_imgs)
    out = self.classifier(z)
    out = F.log_softmax(out, dim=1)
    loss = nn.NLLLoss()(out, labels)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def validation_step(self, batch, batch_idx):
    processed_imgs, labels = batch
    z = self.encoder.get_image_features(processed_imgs)
    out = self.classifier(z)
    out = F.log_softmax(out, dim=1)
    loss = nn.NLLLoss()(out, labels)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr, weight_decay=decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    return [optimizer], [scheduler]

In [18]:
model_nano = NanoCLIPClassifier()
wandb_logger = WandbLogger(name=f"Nano - lr {lr} weight decay {decay} step size {step_size} gamma {gamma}", project="CLIPClassifier")
trainer = pl.Trainer(max_epochs=500, logger=wandb_logger)
trainer.fit(model_nano, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type       | Params
------------------------------------------
0 | encoder    | CLIPModel  | 151 M 
1 | classifier | Sequential | 157 K 
------------------------------------------
157 K     Trainable params
151 M     Non-trainable params
151 M     Total params
605.738   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [19]:
torch.save(model.state_dict(), '/content/drive/My Drive/VisionLab/nano_model.pth')

In [20]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\r'), FloatProgress(value=0.1060918946824987, max=1.0…

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,21.0
train_loss,0.0072
trainer/global_step,4773.0
val_loss,0.11131


# Nano Nano Model - Reduce Overfitting?

In [21]:
class NanoNanoCLIPClassifier(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    self.encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    self.classifier = nn.Sequential(nn.Linear(512, 101))

    for param in self.encoder.parameters():
      param.requires_grad = False

  def forward(self, x):
    processed = self.processor(x, return_tensors="pt")["pixel_values"].squeeze()
    embedding = self.encoder.get_image_features(processed)
    out = self.classifier(embedding)
    out = F.softmax(out, dim=1)
    return torch.argmax(out, dim=1)

  def training_step(self, batch, batch_idx):
    processed_imgs, labels = batch
    z = self.encoder.get_image_features(processed_imgs)
    out = self.classifier(z)
    out = F.log_softmax(out, dim=1)
    loss = nn.NLLLoss()(out, labels)
    self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def validation_step(self, batch, batch_idx):
    processed_imgs, labels = batch
    z = self.encoder.get_image_features(processed_imgs)
    out = self.classifier(z)
    out = F.log_softmax(out, dim=1)
    loss = nn.NLLLoss()(out, labels)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr, weight_decay=decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    return [optimizer], [scheduler]

In [22]:
model_nano_nano = NanoNanoCLIPClassifier()
wandb_logger = WandbLogger(name=f"NanoNano - lr {lr} weight decay {decay} step size {step_size} gamma {gamma}", project="CLIPClassifier")
trainer = pl.Trainer(max_epochs=500, logger=wandb_logger)
trainer.fit(model_nano_nano, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type       | Params
------------------------------------------
0 | encoder    | CLIPModel  | 151 M 
1 | classifier | Sequential | 51.8 K
------------------------------------------
51.8 K    Trainable params
151 M     Non-trainable params
151 M     Total params
605.317   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [23]:
torch.save(model.state_dict(), '/content/drive/My Drive/VisionLab/nano_nano_model.pth')

In [None]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))