<a href="https://colab.research.google.com/github/thegenerativegeneration/CLIP/blob/main/tgg_CLIP_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Print nvidia-smi's status table for the GPU(s) attached to this runtime.
!nvidia-smi

In [None]:
import subprocess
simple_nvidia_smi_display = False#@param {type:"boolean"}
# Simple mode only lists the GPUs (`nvidia-smi -L`). Otherwise print the full
# status table, then the output of `nvidia-smi -i 0 -e 0`.
# NOTE(review): `-e 0` asks nvidia-smi to change the ECC setting of GPU 0
# (a configuration command, not a read-only query); its output is kept as
# `nvidiasmi_ecc_note`. Confirm this is intended before relying on it.
if not simple_nvidia_smi_display:
    nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(nvidiasmi_output)
    nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(nvidiasmi_ecc_note)
else:
    nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(nvidiasmi_output)


import subprocess, os, sys, ipykernel

def gitclone(url, targetdir=None):
    """Run ``git clone url [targetdir]`` and print git's captured stdout.

    Args:
        url: repository URL to clone.
        targetdir: optional destination directory; when falsy, git picks the
            default directory name derived from the URL.
    """
    command = ['git', 'clone', url]
    if targetdir:
        command.append(targetdir)
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    print(completed.stdout.decode('utf-8'))

def pipi(modulestr):
    """Run ``pip install <modulestr>`` and print pip's captured stdout."""
    raw_stdout = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout
    print(raw_stdout.decode('utf-8'))

def pipie(modulestr):
    """Install *modulestr* in editable mode (``pip install -e``) and print pip's stdout.

    Args:
        modulestr: a local project path or VCS URL accepted by ``pip install -e``.
    """
    # Must invoke pip: "git install" is not a git subcommand, so the previous
    # command always failed without installing anything.
    res = subprocess.run(['pip', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(res)

def wget(url, outputdir):
    """Download *url* into directory *outputdir* via ``wget -P`` and print wget's stdout."""
    argv = ['wget', url, '-P', f'{outputdir}']
    print(subprocess.run(argv, stdout=subprocess.PIPE).stdout.decode('utf-8'))

# Detect Google Colab by importing its Drive helper; anything else is treated as
# a local run. Only ImportError means "not Colab" — a bare except would also
# swallow unrelated failures (and even KeyboardInterrupt).
try:
    from google.colab import drive
    print("Google Colab detected. Using Google Drive.")
    is_colab = True
    #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.
    google_drive = True #@param {type:"boolean"}
    #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:
    save_models_to_google_drive = True #@param {type:"boolean"}
except ImportError:
    is_colab = False
    google_drive = False
    save_models_to_google_drive = False
    print("Google Colab not detected.")

# Working root: the mounted Drive folder on Colab with Drive enabled, /content
# on Colab without Drive, and the current directory everywhere else.
if is_colab:
    if google_drive:
        drive.mount('/content/drive')
        root_path = '/content/drive/MyDrive/AI/Disco_Diffusion'
    else:
        root_path = '/content'
else:
    root_path = os.getcwd()

import os
def createPath(filepath):
    """Create directory *filepath*, including missing parents.

    Does nothing if the directory already exists.
    """
    os.makedirs(name=filepath, exist_ok=True)


import pathlib, shutil, os, sys

# There are some reports that with a T4 or V100 on Colab, downgrading to a previous version of PyTorch may be necessary.
# .. but there are also reports that downgrading breaks them!  If you're facing issues, you may want to try uncommenting and running this code.
# nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')
# cards_requiring_downgrade = ["Tesla T4", "V100"]
# if is_colab:
#     if any(cardstr in nvidiasmi_output for cardstr in cards_requiring_downgrade):
#         print("Downgrading pytorch. This can take a couple minutes ...")
#         downgrade_pytorch_result = subprocess.run(['pip', 'install', 'torch==1.10.2', 'torchvision==0.11.3', '-q'], stdout=subprocess.PIPE).stdout.decode('utf-8')
#         print("pytorch downgraded.")

#@markdown Check this if you want to use CPU
useCPU = False #@param {type:"boolean"}
# When True, DEVICE (chosen after the imports further down) is forced to 'cpu' even if CUDA is available.

if not is_colab:
    # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.
    os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

# Absolute working directory; the cloned CLIP checkout is appended to sys.path relative to it.
PROJECT_DIR = os.path.abspath(os.getcwd())
# NOTE(review): USE_ADABINS is not read anywhere in the visible notebook code — presumably a template leftover; confirm before removing.
USE_ADABINS = True

# Decide where model files live.
# Previously `model_path` was left undefined for Colab with Drive mounted, so
# any later use raised NameError. In that case, honour the
# save_models_to_google_drive checkbox: keep models under the Drive root_path
# chosen earlier, or fall back to local /content storage.
if is_colab:
    if google_drive:
        model_path = f'{root_path}/models' if save_models_to_google_drive else '/content/models'
    else:
        root_path = '/content'
        model_path = '/content/models'
else:
    root_path = os.getcwd()
    model_path = f'{root_path}/models'

# Python dependencies for this notebook. "datetime" is intentionally absent:
# it is a standard-library module, and the PyPI package of that name is the
# unrelated Zope DateTime distribution.
multipip_res = subprocess.run(['pip', 'install', 'lpips', 'timm', 'ftfy', 'einops', 'pytorch-lightning', 'omegaconf'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print(multipip_res)

if is_colab:
    # -y answers apt's confirmation prompt, so the non-interactive subprocess
    # never waits for (or aborts on) user input.
    subprocess.run(['apt', 'install', '-y', 'imagemagick'], stdout=subprocess.PIPE)

# Make the CLIP package importable. If it is not importable yet, clone
# openai/CLIP into the working directory (only once) and add the checkout to
# sys.path. Only ImportError is handled; other errors should surface.
try:
    from CLIP import clip
except ImportError:
    if not os.path.exists("CLIP"):
        gitclone("https://github.com/openai/CLIP")
    sys.path.append(f'{PROJECT_DIR}/CLIP')



import torch
from dataclasses import dataclass
from functools import partial
import cv2
import pandas as pd
import gc
import io
import math
import timm
from IPython import display
import lpips
from PIL import Image, ImageOps
import requests
from glob import glob
import json
from types import SimpleNamespace
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm
from CLIP import clip
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import random
from ipywidgets import Output
import hashlib
from functools import partial

from IPython.display import Image as ipyimg
from numpy import asarray
from einops import rearrange, repeat
import torch, torchvision
import time
from omegaconf import OmegaConf
import warnings

# Use the first CUDA GPU unless CUDA is unavailable or useCPU is set.
DEVICE = torch.device('cuda:0' if (torch.cuda.is_available() and not useCPU) else 'cpu')
print('Using device:', DEVICE)
device = DEVICE # At least one of the modules expects this name..

# Query CUDA capabilities only when a CUDA device was actually selected.
# Checking `not useCPU` was not enough: on a machine without CUDA, DEVICE is
# 'cpu' and get_device_capability would raise.
if DEVICE.type == 'cuda':
    if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad
        print('Disabling CUDNN for A100 gpu', file=sys.stderr)
        torch.backends.cudnn.enabled = False



In [None]:
# Extract the dataset archive from Drive into /content.
# -o: overwrite existing files without prompting, so re-running the cell cannot block.
# -q: suppress the per-file listing.
!cd /content && unzip -q -o /content/drive/MyDrive/data/clip_data_v9.zip

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from IPython.display import HTML, display
import os


# Images are read from {dataset_root}/images and captions from {dataset_root}/text/<image stem>.txt.
dataset_root = "/content/clip_data"#@param
# Training checkpoints (and the WiSE-FT mix) are saved here; the directory is created if missing.
checkpoint_path = "/content/drive/MyDrive/AI/clip_checkpoints_v9"#@param

os.makedirs(checkpoint_path, exist_ok=True)

# Batch size also sets how many in-batch negatives each image/caption pair is contrasted against.
BATCH_SIZE = 128#@param
# The training loop iterates range(EPOCH+1), i.e. it runs EPOCH + 1 epochs.
EPOCH = 100#@param
model_suffix = "_ukiyoe"#@param
clip_model_name = "ViT-B/32"#@param ["ViT-B/32", "ViT-B/16", "ViT-L/14", "RN50", "RN50x4", "RN50x16", "RN50x64", "RN101", "ViT-L/14@336px"]
# File-name-safe checkpoint prefix, e.g. "ViT-B/32" + "_ukiyoe" -> "ViT-B32_ukiyoe".
model_name = clip_model_name.replace('/', '') + model_suffix

In [None]:

# Must set jit=False for training (the JIT-compiled model is not used for fine-tuning).
model, preprocess = clip.load(
    clip_model_name,
    device=device,
    jit=False,
)


In [None]:
#@markdown Choose which layers you don't want to train. Checked layers have their gradients disabled; only unchecked layers are trained. The layer names correspond to the ViT-B/32 architecture.

freeze_visual = False#@param{type: 'boolean'}
freeze_token_embedding = True#@param{type: 'boolean'}
freeze_transformer = False#@param{type: 'boolean'}
freeze_ln_final = True#@param{type: 'boolean'}

# Top-level child module name -> whether its parameters should stop receiving gradients.
freeze_flags = {
    "visual": freeze_visual,
    "token_embedding": freeze_token_embedding,
    "transformer": freeze_transformer,
    "ln_final": freeze_ln_final,
}
for child_name, child_module in model.named_children():
    if not freeze_flags.get(child_name, False):
        continue
    print(f"Freezing: {child_name}")
    for param in child_module.parameters():
        param.requires_grad = False

    

# Training

In [None]:
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, RandomResizedCrop
# Prefer torchvision's InterpolationMode enum. When it cannot be imported,
# fall back to PIL's bicubic resampling constant.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC

def progress(value, max=100):
    """Return an IPython ``HTML`` object that renders a full-width ``<progress>`` bar.

    Args:
        value: current progress value (also rendered as fallback text).
        max: value corresponding to a full bar. The name shadows the builtin
            but is kept so existing keyword callers keep working.

    Returns:
        An ``HTML`` display object, suitable for ``display(..., display_id=True)``
        and ``handle.update(...)``.
    """
    # Attributes are separated by whitespace only. A comma after max='...'
    # made the markup invalid: HTML parsers read it as a bogus attribute ",".
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 100%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))
def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
    """Build the training-time augmentation pipeline for square ``n_px`` inputs.

    Steps, in order:
    - random square crop covering 80-100% of the source area, resized to ``n_px``;
    - horizontal flip with probability 0.5;
    - RGB conversion;
    - tensor conversion;
    - per-channel mean/std normalisation.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        RandomResizedCrop(n_px, scale=(0.8, 1.0), ratio=(1.0, 1.0), interpolation=BICUBIC),
        RandomHorizontalFlip(0.5),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)

class image_title_dataset(Dataset):
    """Map-style dataset of (transformed image, tokenized caption) pairs.

    Args:
        list_image_path: image file paths; entry ``i`` pairs with caption ``i``.
        list_txt: caption strings, tokenized once up front with
            ``clip.tokenize(..., truncate=True)``.
        transforms: callable applied to each opened PIL image.
    """

    def __init__(self, list_image_path, list_txt, transforms):
        self.image_path = list_image_path
        # Tokenizing everything here is slow at start-up but keeps __getitem__ cheap.
        # truncate=True cuts captions that exceed the tokenizer's context length.
        self.title = clip.tokenize(list_txt, truncate=True)
        self.transforms = transforms

    def __len__(self):
        return len(self.title)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released as soon
        # as the transform has run, not whenever the PIL image is garbage-collected
        # (matters for many DataLoader iterations and multi-frame files).
        with Image.open(self.image_path[idx]) as pil_image:
            image = self.transforms(pil_image)
        title = self.title[idx]
        return image, title

# use your own data
list_image_path = [] 
list_txt = []

print("Loading dataset")

source_images_path = f"{dataset_root}/images"
files = [f for f in os.listdir(source_images_path) if os.path.isfile(f"{source_images_path}/{f}") and (f.endswith(".png") or f.endswith(".jpg") or f.endswith(".jpeg") or f.endswith(".webp"))]

ttl = len(files)
out = display(progress(0, ttl), display_id=True)
ix = 0
for f in files:
    out.update(progress(ix, ttl))
    ix+= 1
    im_data_number = f.split('.')[0]
    if not os.path.isfile(f"{dataset_root}/text/{im_data_number}.txt"):
        continue
    with open(f"{dataset_root}/text/{im_data_number}.txt") as textfile:
        lines = textfile.readlines()
    for line in lines:
      if line:
        image_path = f"{source_images_path}/{f}"
        list_image_path.append(image_path)
        list_txt.append(line)
    

dataset = image_title_dataset(list_image_path, list_txt, _transform(model.visual.input_resolution))
train_dataloader = DataLoader(dataset,batch_size = BATCH_SIZE, shuffle=True) #Define your own dataloader

#https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    """Cast every parameter of *model*, and any existing gradient, to float32 in place.

    Used before ``optimizer.step()`` so the update runs in fp32 on weights that
    are otherwise held in half precision.
    See https://github.com/openai/CLIP/issues/57.
    """
    for param in model.parameters():
        param.data = param.data.float()
        grad = param.grad
        if grad is None:
            continue
        grad.data = grad.data.float()



# Symmetric contrastive objective: cross-entropy over image->text logits and
# over text->image logits, averaged.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=5e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.002) # reduced learning rate and weight decay - recommended by various papers

# Write checkpoints every SAVE_EVERY epochs (1 = every epoch).
SAVE_EVERY = 1

# `device` is a torch.device here, and torch.device("cpu") == "cpu" evaluates
# to False. The old string comparison therefore always took the fp16 branch,
# even on CPU, where it cast the weights to half after every step.
# Compare the device type instead.
on_cpu = torch.device(device).type == "cpu"

# add your own code to track the training progress.
print("Starting training")
# NOTE: range(EPOCH + 1) runs EPOCH + 1 epochs, numbered 0..EPOCH (checkpoints _e0 .. _e{EPOCH}).
for epoch in range(EPOCH+1):
    print(f"EPOCH {epoch}")
    epoch_loss = 0.0
    num_batches = 0
    for i, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # Image i matches caption i, so the targets are 0..B-1 (the diagonal).
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        img_loss = loss_img(logits_per_image, ground_truth)
        txt_loss = loss_txt(logits_per_text, ground_truth)

        total_loss = (img_loss + txt_loss) / 2
        total_loss.backward()

        if on_cpu:
            optimizer.step()
        else:
            # Off CPU the weights are kept in half precision (hence convert_weights
            # afterwards): step in fp32, then cast back. Workaround from
            # https://github.com/openai/CLIP/issues/57.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)
        epoch_loss += total_loss.item()
        num_batches += 1
        if i % 10 == 0:
            print(f'image loss: {img_loss.item():>7f} text loss: {txt_loss.item():>7f} total loss: {total_loss.item()} batch: {i+1}')
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f" Epoch loss: {epoch_loss} (mean per batch: {mean_loss:.6f})")
    if epoch % SAVE_EVERY == 0:
        print(f"Saving checkpoint: {checkpoint_path}/{model_name}_e{epoch}.pt")
        torch.save(model.state_dict(), f"{checkpoint_path}/{model_name}_e{epoch}.pt")
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizers_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                }, f"{checkpoint_path}/{model_name}_e{epoch}_full.pt")

In [None]:
# Spot-check one caption file from the dataset currently in use. The path is built
# from dataset_root instead of the stale /content/clip_data_v6 location.
# NOTE(review): assumes a caption file with stem 7603-49 exists in this dataset version.
!cat {dataset_root}/text/7603-49.txt

# TODO: WiSE-FT (weight-space ensembling of the zero-shot and fine-tuned CLIP weights)

In [None]:
# WiSE-FT: interpolate in weight space between the zero-shot CLIP weights
# theta_0 and a fine-tuned checkpoint theta_1:
#     theta = (1 - alpha) * theta_0 + alpha * theta_1
# alpha = 0 reproduces the zero-shot weights; alpha = 1 reproduces the fine-tuned ones.
alpha = 0.75
#finetuned_model = model

# Training epoch whose checkpoint is mixed with the zero-shot weights.
epoch = 20
# Build the path from the same checkpoint_path/model_name the training loop saves
# with. A hard-coded copy silently breaks when those settings change.
finetuned_model_statedict = torch.load(f"{checkpoint_path}/{model_name}_e{epoch}.pt", map_location=device)
zeroshot_model, preprocess = clip.load(clip_model_name,device=device,jit=False)

theta_0 = zeroshot_model.state_dict()
theta_1 = finetuned_model_statedict

# make sure checkpoints are compatible (explicit check: asserts are stripped under -O)
if set(theta_0.keys()) != set(theta_1.keys()):
    mismatched_keys = sorted(set(theta_0.keys()) ^ set(theta_1.keys()))
    raise ValueError(f"Checkpoint keys do not match the zero-shot model, e.g. {mismatched_keys[:10]}")

# Interpolate between checkpoints with mixing coefficient alpha. Non-floating-point
# entries (such as integer buffers) cannot be blended meaningfully, so they are
# taken from the fine-tuned checkpoint unchanged.
theta = {
    key: ((1 - alpha) * theta_0[key] + alpha * theta_1[key])
    if theta_0[key].is_floating_point()
    else theta_1[key]
    for key in theta_0.keys()
}

# update the model according to the new weights
#finetuned_model.load_state_dict(theta)


torch.save(theta, f"{checkpoint_path}/{model_name}_e{epoch}-wiseft_alpha-{alpha}.pt")

In [None]:
# `finetuned_model` only exists if the commented-out `finetuned_model = model`
# line in the WiSE-FT cell is re-enabled. Guard so this inspection cell does not
# raise NameError otherwise.
if 'finetuned_model' in globals():
    print(finetuned_model.__dict__)
else:
    print("finetuned_model is not defined; uncomment `finetuned_model = model` in the WiSE-FT cell first.")

In [None]:
# Collect unreachable Python objects first, so tensors they still reference are
# freed. Then release PyTorch's cached, currently unused GPU memory blocks.
gc.collect()
torch.cuda.empty_cache()