In [None]:
import torch
from torch import nn
import os
import torch.optim as optim
import time
from tqdm import tqdm
import datasets
from torch.utils.data import DataLoader

class CLIP_Adapted_ImageEncoder(nn.Module):
    """CLIP image encoder followed by a small MLP adapter.

    The adapter is three ``Linear -> LayerNorm`` stages with a ``LeakyReLU``
    between consecutive stages; the feature width is the same at every stage.

    Args:
        clip_model: Object exposing ``encode_image(image)``. The last dimension
            of what it returns must equal ``embed_dim``.
        embed_dim: Width of the image embedding fed to the adapter. Defaults to
            512, the value that used to be hard-coded here, so existing callers
            are unaffected.
    """

    def __init__(self, clip_model, embed_dim=512):
        super().__init__()
        self.clip_model = clip_model

        # Keep the layer order unchanged: the positions inside the Sequential
        # become the ``adapter.<idx>.*`` parameter names, so reordering would
        # break loading of previously saved adapter weights.
        self.adapter = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),

            nn.LeakyReLU(),

            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),

            nn.LeakyReLU(),

            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )

    def forward(self, image):
        """Encode ``image`` with ``clip_model.encode_image`` and adapt the result."""
        return self.adapter(self.clip_model.encode_image(image))
    
class CLIP_Adapted_TextEncoder(nn.Module):
    """CLIP text encoder followed by a small MLP adapter.

    The adapter is three ``Linear -> LayerNorm`` stages with a ``LeakyReLU``
    between consecutive stages; the feature width is the same at every stage.

    Args:
        clip_model: Object exposing ``encode_text(text)``. The last dimension
            of what it returns must equal ``embed_dim``.
        embed_dim: Width of the text embedding fed to the adapter. Defaults to
            512, the value that used to be hard-coded here, so existing callers
            are unaffected.
    """

    def __init__(self, clip_model, embed_dim=512):
        super().__init__()
        self.clip_model = clip_model

        # Keep the layer order unchanged: the positions inside the Sequential
        # become the ``adapter.<idx>.*`` parameter names, so reordering would
        # break loading of previously saved adapter weights.
        self.adapter = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),

            nn.LeakyReLU(),

            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),

            nn.LeakyReLU(),

            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )

    def forward(self, text):
        """Encode ``text`` with ``clip_model.encode_text`` and adapt the result."""
        return self.adapter(self.clip_model.encode_text(text))
    
class DA_adapter(nn.Module):
    """Wraps one CLIP model with an adapted image encoder and an adapted text encoder.

    The CLIP model passed in is frozen in place (all of its parameters get
    ``requires_grad=False``) and is shared by both encoders.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Freeze the backbone before handing it to the encoders, so that only
        # the layers the encoders create themselves remain trainable.
        clip_model.requires_grad_(False)
        self.ad_imageEncoder = CLIP_Adapted_ImageEncoder(clip_model=clip_model)
        self.ad_textEncoder = CLIP_Adapted_TextEncoder(clip_model=clip_model)

    def forward(self, image, text):
        """Return the pair ``(adapted image features, adapted text features)``."""
        return self.ad_imageEncoder(image), self.ad_textEncoder(text)

In [None]:
from argparse import ArgumentParser
import model_resume
import open_clip

def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--resume False`` would
    silently enable resuming. This converter accepts the usual spellings
    (case-insensitive) and rejects anything else.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: If the token is not a recognised boolean spelling; argparse
            turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False because bool('') was False before.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


# Command-line style configuration for the run. In the notebook these are
# parsed from an empty argument list, so the defaults below are what apply.
parser = ArgumentParser(description='LDUN')
parser.add_argument('--about', type=str, default='5task')
parser.add_argument('--start_epoch', type=int, default=0, help='epoch number of start training')
parser.add_argument('--end_epoch', type=int, default=100, help='epoch number of end training')
parser.add_argument('--learning_rate', type=float, default=1e-5, help='learning rate')
parser.add_argument('--resume', type=_str2bool, default=False, help='is resume')
parser.add_argument('--group_num', type=int, default=1, help='group number for training')
parser.add_argument('--gpu_list', type=str, default='1', help='gpu index')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/2/', help='checkpoints dir')
parser.add_argument('--log_dir', type=str, default='log', help='log directory')
parser.add_argument('--ext', type=str, default='.png', help='file extension of training data files')
parser.add_argument('--is_aug', type=_str2bool, default=False, help='is aug')
parser.add_argument('--is_clip_tuning', type=_str2bool, default=False, help='is finetuning clip')
parser.add_argument('--patch_size', type=int, default=128, help='patchsize of input.')

# Training path lists and the text prompt used for each list.
# Entry i of a *_text list is the prompt paired with entry i of the path list.

# 3-task setting: Noise (three lists: 15 / 25 / 50), Haze, Rain
NHR = [
    "./datasets/denoising_datasets/15_train_paths.txt",
    "./datasets/denoising_datasets/25_train_paths.txt",
    "./datasets/denoising_datasets/50_train_paths.txt",
    "./datasets/dehazing_datasets/train_paths.txt",
    "./datasets/deraining_datasets/Rain100L/train_paths.txt",
]

# 5-task setting: everything in NHR, then Blur and Lowlight
NHRBL = NHR + [
    "./datasets/deblurring_datasets/GoPro/train_paths.txt",
    "./datasets/delowlight_datasets/LoL/train_paths.txt",
]

NHR_text = [
    "image with noise",
    "image with many noise",
    "image with large amount of noise",
    "image with haze",
    "image with rain",
]

NHRBL_text = NHR_text + [
    "image with blur",
    "image with low light",
]

In [None]:
# Parse an empty argument list so every option takes its default value.
args = parser.parse_args(args=[])

# Restrict and order the visible GPUs before the CUDA availability check below.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": args.gpu_list,
})

# "cuda:0" is the first *visible* device; fall back to the CPU without CUDA.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the pretrained OpenCLIP ViT-B/32 backbone, its preprocessing transform
# and the matching tokenizer. The backbone gets its own name so that `model`
# only ever refers to the adapter-wrapped network.
clip_model, preprocess, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# open_clip hands the model back in train mode. DA_adapter freezes the
# backbone, so put it in eval mode to keep any train-only layers inactive.
clip_model.eval()

# Assign the result of .to(device) so the cell does not print the full module
# repr as its output (Module.to returns the module itself).
model = DA_adapter(clip_model).to(device)

In [None]:
# Epoch and learning-rate bookkeeping; start_epoch is overridden when resuming.
start_epoch = args.start_epoch
new_lr = args.learning_rate

# Adam over everything model.parameters() yields. Parameters that never
# receive a gradient are left untouched by the optimizer step.
optimizer = optim.Adam(
    model.parameters(),
    lr=args.learning_rate,
    betas=(0.9, 0.999),
    eps=1e-8,
)

######### Resume ###########
if args.resume:
    # NOTE(review): this looks for a '_best.pth' checkpoint, while the training
    # cell writes 'model_epoch_{epoch}.pth' — confirm model_resume can find and
    # read those files.
    path_chk_rest = model_resume.get_last_path(args.checkpoints_dir, '_best.pth')
    model_resume.load_checkpoint(model, path_chk_rest)
    # Continue with the epoch after the one stored in the checkpoint.
    start_epoch = model_resume.load_start_epoch(path_chk_rest) + 1
    model_resume.load_optim(optimizer, path_chk_rest)

######### Loss ###########
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.L1Loss()

In [None]:
from utils.clip_dataset import *
# Training set built from the 5-task path lists; TrainDataset_forIR comes from
# the star import of utils.clip_dataset.
trainset = TrainDataset_forIR(NHRBL, args)

# Reshuffled every epoch, 32 samples per batch, loaded by 16 worker processes.
train_loader = DataLoader(
    dataset=trainset,
    batch_size=32,
    shuffle=True,
    num_workers=16,
)

In [None]:
# Tokenize the prompts once. The tokens never change, but the text features
# are recomputed on every step because the text adapter is being trained.
texts = tokenizer(NHRBL_text).to(device)

# torch.save does not create missing directories, so make sure the checkpoint
# directory exists before the first save.
os.makedirs(args.checkpoints_dir, exist_ok=True)

# Start at `start_epoch` rather than 0: when resuming, this continues after the
# restored epoch instead of repeating epochs and overwriting the earlier
# model_epoch_*.pth files.
for epoch in range(start_epoch, args.end_epoch):
    epoch_loss = 0
    epoch_start_time = time.time()
    for i, data in enumerate(tqdm(train_loader), 0):
        images = data[2].to(device) # [b, 3, 224, 224]
        labels = data[3].to(device) # class targets for CrossEntropyLoss; presumably indices into NHRBL_text — TODO confirm

        image_features, text_features = model(images, texts)

        # One score per (image, prompt) pair.
        # NOTE(review): the features are not L2-normalised before the dot
        # product — confirm this is intended.
        logits_per_image = image_features @ text_features.T
        # selected_text_features = text_features[labels]

        # loss = criterion1(logits_per_image, labels) + criterion2(image_features, selected_text_features)
        loss = criterion1(logits_per_image, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss +=loss.item()
        # Periodic debug dump of the current batch's predictions.
        if i % 2500 == 0:
            text_probs = (100.0 * logits_per_image).softmax(dim=-1)
            print(f"Text Probs: {text_probs}")
            print(f"Logits per Image: {logits_per_image}")
            print(f"Lables: {labels}")
    epoch_end_time = time.time()
    # epoch_loss is the sum of the per-batch losses, not their mean.
    print(f"Epoch {epoch+1}, Epoch Loss: {epoch_loss}, Time: {(epoch_end_time-epoch_start_time)}")

    # Every 10th epoch, save only the trainable parameters plus the optimizer state.
    if epoch % 10 == 0:
        learnable_params = {name: param for name, param in model.named_parameters() if param.requires_grad}
        torch.save({'epoch': epoch, 
                    'learnable_params': learnable_params,
                    'optimizer' : optimizer.state_dict()
                    }, os.path.join(args.checkpoints_dir,f"model_epoch_{epoch}.pth"))