In [1]:
import cv2
import numpy as np
from torchvision.datasets import CocoCaptions
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize

  from .autonotebook import tqdm as notebook_tqdm


## Unet

In [2]:
import Detection.GroundingDINO.groundingdino.datasets.transforms as T

# Preprocessing applied to every image before it is handed to the detector:
# RandomResize with the single candidate size 800 (max_size 1333), tensor
# conversion, then per-channel normalisation with the usual ImageNet mean/std.
_resize_step = T.RandomResize([800], max_size=1333)
_normalize_step = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

transform = T.Compose([_resize_step, T.ToTensor(), _normalize_step])


In [3]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """COCO-Captions wrapper that yields one randomly chosen caption per image.

    Each item is a triple ``(raw_image_array, transformed_image, caption)``:
    the untouched image converted with ``np.array``, the first element returned
    by ``transform(image, None)``, and one caption drawn uniformly at random
    from the captions of that image.
    """

    # (image root, annotation file) for each split. Any phase other than
    # 'train' falls back to the validation split.
    _TRAIN_SPLIT = ('./train2017', './annotations/captions_train2017.json')
    _VAL_SPLIT = ('./val2017', './annotations/captions_val2017.json')

    def __init__(self, phase, transform):
        self.phase = phase
        self.transform = transform

        root, ann_file = self._TRAIN_SPLIT if phase == 'train' else self._VAL_SPLIT
        self.dataset = CocoCaptions(root=root, annFile=ann_file)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        org_img, caps = self.dataset[idx]

        # The transform is called with (image, target) and returns a pair;
        # only the transformed image is kept.
        transformed = self.transform(org_img, None)
        img = transformed[0]

        # Draw the caption index with torch's RNG, then use it as a plain int.
        picked = int(torch.randint(0, len(caps), (1,)))

        return np.array(org_img), img, caps[picked]


In [4]:
from SResolution.TG_Umodels import *

def select_model(down_scale, shape):
    """Build the super-resolution network matching a down-scaling factor.

    Args:
        down_scale: Down-scaling factor of the low-resolution input; one of
            2, 4 or 8.
        shape: Input shape forwarded to the text-guided models (factors 2 and
            4). It is not used for factor 8.

    Returns:
        A freshly constructed model with 3 input and 3 output channels.

    Raises:
        ValueError: If ``down_scale`` is not one of the supported factors.
            Previously this fell through to ``return model`` with ``model``
            unbound and surfaced as an UnboundLocalError.
    """
    if down_scale == 2:
        return TG_UNet2(3, 3, shape)
    if down_scale == 4:
        return TG_UNet4(3, 3, shape)
    if down_scale == 8:
        # UNet8 is constructed without the shape argument.
        return UNet8(3, 3)

    raise ValueError(
        f'Unsupported down_scale {down_scale!r}; expected one of 2, 4, 8.'
    )

In [5]:


# --- Experiment configuration -------------------------------------------------
# The network reconstructs a 512x512 target from an input that is `down_scale`
# times smaller on each side.
down_scale = 2
input_shape = (512 // down_scale,) * 2
model = select_model(down_scale, input_shape)

# Optimisation settings.
epochs = 200
lr = 1e-5
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Use the first GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

print(device)


cuda:0


In [6]:
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler, DataLoader, Subset

# Train and validate on small fixed prefixes of COCO so that one epoch stays
# short. Each split gets its own size constant and subset name; the previous
# version reused `num_train_samples` / `sample_ds` for both splits (leaving
# `num_train_samples == 200` afterwards) and built RandomSampler objects that
# were never passed to a DataLoader.
trainDataset = CustomDataset('train', transform)
num_train_samples = 1000
train_subset = Subset(trainDataset, range(num_train_samples))
# batch_size=1: images have different sizes, so they are not stacked.
train_dataloader = DataLoader(train_subset, batch_size=1, shuffle=True)

valDataset = CustomDataset('val', transform)
num_val_samples = 200
val_subset = Subset(valDataset, range(num_val_samples))
val_dataloader = DataLoader(val_subset, batch_size=1, shuffle=False)

loading annotations into memory...
Done (t=0.63s)
creating index...
index created!
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [7]:
# Number of batches per loader; with batch_size=1 this equals the sample count.
n_train_batches = len(train_dataloader)
n_val_batches = len(val_dataloader)
print(n_train_batches, n_val_batches)

1000 200


In [8]:
# Sanity check: pull a single batch and look at what the loader produces
# (raw image batch, transformed image batch, collated captions).
first_batch = next(iter(train_dataloader))
oo, a, c = first_batch

(oo.shape, a.shape, type(c))

(torch.Size([1, 443, 640, 3]), torch.Size([1, 3, 800, 1155]), tuple)

In [9]:
from Detection.GroundingDINO.groundingdino.util.inference import load_model, load_image_6, predict, annotate
from tqdm import tqdm 
import time
from torchvision.ops import box_convert
from torchvision.transforms.functional import center_crop
from copy import deepcopy


iterable = range(epochs)

# Frozen text-conditioned detector used to find the regions that get super-resolved.
dect_model = load_model("./Detection/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "./Detection/GroundingDINO//weights/groundingdino_swint_ogc.pth")
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25

# Side length of the square canvas every detected crop is padded/resized to.
target_size = 512

model.to(device)


def _build_record(oo, bimg, caps, keep_eval_extras):
    """Detect caption-grounded boxes in one batch and prepare the crops.

    Args:
        oo: Raw image batch of size 1, laid out (1, H, W, C).
        bimg: Transformed image batch of size 1 that is fed to the detector.
        caps: Collated captions; only ``caps[0]`` is used.
        keep_eval_extras: When True, also store the original image
            (``'org_img'``) and a down-scaled copy of each raw crop
            (``'detect'``). ``save_imgs`` reads both.

    Returns:
        ``(record, txt)`` where ``record`` is the per-caption dict of lists and
        ``txt`` is ``encoded_text`` from the detector's output.
    """
    record = {
        'phrases': [],
        'org_img': [],
        'shape': [],
        'cropped': [],
        'padded': [],
        'resized': [],
        'output': [],
        'c_padded': [],
    }
    if keep_eval_extras:
        record['detect'] = []

    b_boxes, b_logits, b_phrases, b_encoded_text = predict(
        model=dect_model,
        # squeeze(0) drops only the batch axis, so an image with a size-1
        # spatial dimension keeps its rank.
        image=bimg.squeeze(0),
        caption=str(caps[0]),
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD,
    )
    record['phrases'].append(b_phrases)

    image = oo.squeeze(0)
    h, w, _ = image.shape
    # Boxes are multiplied by the image size to get pixel coordinates, then
    # converted from centre format to corner format.
    boxes_px = b_boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes_px, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    if keep_eval_extras:
        record['org_img'].append(image)

    for box in xyxy.astype(int):
        x1, y1, x2, y2 = (int(v) for v in box)
        # A box hanging over the top/left edge has a negative corner. Used as a
        # slice start that would index from the far side of the image and give
        # a wrong or empty crop, so clamp it to the image border.
        x1, y1 = max(x1, 0), max(y1, 0)

        # permute(2, 1, 0) turns the (H, W, C) crop into (C, W, H): the spatial
        # axes are swapped, so `hh` is the crop width and `ww` its height.
        # save_imgs undoes this with the same permutation.
        cropped_img = image[y1:y2, x1:x2].permute(2, 1, 0)
        _, hh, ww = cropped_img.shape
        record['shape'].append((hh, ww))
        record['cropped'].append(cropped_img)

        # Centre the crop on a target_size canvas. F.pad takes the last
        # dimension first. When a gap is odd the canvas comes out one pixel
        # short and the resize below stretches it to the exact size.
        nh, nw = target_size - hh, target_size - ww
        padd = (nw // 2, nw // 2, nh // 2, nh // 2)
        pimg = torch.nn.functional.pad(cropped_img, padd)
        pimg = resize(pimg, (target_size, target_size))
        record['padded'].append(pimg)

        # Low-resolution network input: the padded canvas shrunk by down_scale.
        _, ph, pw = pimg.shape
        record['resized'].append(resize(pimg, (ph // down_scale, pw // down_scale)))

        if keep_eval_extras:
            record['detect'].append(resize(cropped_img, (hh // down_scale, ww // down_scale)))

    return record, b_encoded_text['encoded_text']


def _forward_crops(record, txt, training):
    """Run the model on every crop of ``record`` and return the per-crop losses.

    The loss is the MSE between the centre-cropped output and the
    centre-cropped padded target, in both training and evaluation, so the two
    reported numbers are directly comparable. When ``training`` is True one
    optimizer step is taken per crop.

    Returns:
        List of Python floats, one loss per crop.
    """
    losses = []
    for low_res, target, sha in zip(record['resized'], record['padded'], record['shape']):
        low_res = (low_res / 255.).unsqueeze(0).to(device)
        target = (target / 255.).unsqueeze(0).to(device)

        output = model(low_res, txt)

        # Score only the region that holds real pixels, not the zero padding.
        c_output = center_crop(output, sha)
        c_pa = center_crop(target, sha)

        record['output'].append(c_output.detach().squeeze(0).cpu().numpy())
        record['c_padded'].append(c_pa.detach().squeeze(0).cpu().numpy())

        loss = criterion(c_output, c_pa)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # .item() detaches the value; summing the loss tensors themselves keeps
        # autograd history alive for the whole epoch.
        losses.append(loss.item())
    return losses


def _mean_loss(losses):
    """Average of ``losses``; NaN when no crop was processed."""
    return sum(losses) / len(losses) if losses else float('nan')


best_loss = float('inf')
best_result = None
best_model = None
for e in iterable:
    # ---- Train ----
    model.train()
    train_result = {}
    train_losses = []

    for oo, bimg, caps in tqdm(train_dataloader, desc='Train', leave=False):
        record, txt = _build_record(oo, bimg, caps, keep_eval_extras=False)
        train_result[f'{str(caps[0])}'] = record
        train_losses.extend(_forward_crops(record, txt, training=True))

    avg_loss = _mean_loss(train_losses)
    total = len(train_losses)
    train_text = f'Train | Epoch: [{e+1}/{epochs}] |  MSE: {avg_loss} | Total: {total}'
    print('='*len(train_text))
    print(train_text)

    # ---- Test ----
    model.eval()
    test_result = {}
    test_losses = []

    with torch.no_grad():
        for oo, bimg, caps in tqdm(val_dataloader, desc='Test', leave=False):
            record, txt = _build_record(oo, bimg, caps, keep_eval_extras=True)
            test_result[f'{str(caps[0])}'] = record
            test_losses.extend(_forward_crops(record, txt, training=False))

    avg_loss = _mean_loss(test_losses)
    total = len(test_losses)
    test_text = f'Test  | Epoch: [{e+1}/{epochs}] |  MSE: {avg_loss} | Total: {total}'
    print(test_text)
    print('='*len(train_text))

    # A NaN average (no detections) never compares as smaller, so it cannot
    # replace a real best result.
    if avg_loss < best_loss:
        print('Update results')
        best_result = test_result
        best_loss = avg_loss
        best_model = deepcopy(model.state_dict())



final text_encoder_type: bert-base-uncased


Train:   3%|▎         | 27/1000 [00:17<07:29,  2.16it/s] 

In [None]:
import os
from PIL import Image
from torchvision.transforms.functional import center_crop

def save_imgs(results, root='results'):
    """Write the images stored in an evaluation result dict to disk.

    Layout: ``root/<caption>/org.png`` holds the original image, and for every
    detected phrase ``root/<caption>/<phrase>/`` receives ``output{n}.png``
    (model output), ``big{n}.png`` (full-resolution crop) and ``detect{n}.png``
    (down-scaled crop). ``n`` counts from 1 separately for each phrase, so a
    phrase detected several times never overwrites its earlier files.

    Args:
        results: Mapping from caption to a record with the keys ``'org_img'``,
            ``'phrases'``, ``'output'``, ``'shape'``, ``'cropped'`` and
            ``'detect'``. ``None`` or an empty mapping is a no-op.
        root: Directory the caption folders are created in.
    """
    if not results:
        # Nothing to write, e.g. best_result was never assigned.
        return

    for k, record in results.items():
        # NOTE(review): the caption is used verbatim as a folder name; a caption
        # containing a path separator would create nested folders.
        caption_dir = os.path.join(root, k)
        os.makedirs(caption_dir, exist_ok=True)

        org = Image.fromarray(record['org_img'][0].numpy())
        org.save(os.path.join(caption_dir, 'org.png'))

        phrases = record['phrases'][0]
        if len(phrases) == 0:
            print('ups')
            continue

        # Next file index per phrase. The previous single shared counter was
        # reset whenever a new phrase folder appeared, which could reuse an
        # index inside an existing folder.
        next_index = {}
        for i, phrase in enumerate(phrases):
            # Stored outputs are channel-first with swapped spatial axes;
            # permute(2, 1, 0) restores (H, W, C) for PIL.
            output = torch.tensor(record['output'][i])
            output = center_crop(output, record['shape'][i]).permute(2, 1, 0).numpy()
            output = Image.fromarray((np.clip(output, 0, 1) * 255).astype(np.uint8))

            crop = Image.fromarray(record['cropped'][i].permute(2, 1, 0).numpy())
            detect = Image.fromarray(record['detect'][i].permute(2, 1, 0).numpy())

            phrase_dir = os.path.join(caption_dir, f'{phrase}')
            os.makedirs(phrase_dir, exist_ok=True)

            zz = next_index.get(phrase, 1)
            output.save(os.path.join(phrase_dir, f'output{zz}.png'))
            crop.save(os.path.join(phrase_dir, f'big{zz}.png'))
            detect.save(os.path.join(phrase_dir, f'detect{zz}.png'))
            next_index[phrase] = zz + 1
            

In [None]:
# Echo the active down-scaling factor; it is interpolated into the results
# folder name used by the save_imgs call in the next cell.
down_scale

2

In [None]:
# Export the images of the best validation epoch into a folder named after the
# down-scaling factor.
results_dir = f'resultsX{down_scale}withDetection'
save_imgs(best_result, results_dir)

ups


In [None]:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Image-quality metrics over every crop of the best validation epoch.
# The earlier version read `Test_result_img[...]['big_image']`, which no cell of
# this notebook defines; `best_result` holds the same information as
# 'output' (prediction) and 'c_padded' (ground truth), both channel-first
# float arrays scaled to [0, 1] by the training cell.
SSIM_MIN_SIDE = 7  # structural_similarity uses a 7x7 window by default

psnr_scores, ssim_scores = [], []
for record in (best_result or {}).values():
    for predicted, reference in zip(record['output'], record['c_padded']):
        if min(predicted.shape[1:]) < SSIM_MIN_SIDE:
            continue  # crop is smaller than the SSIM window
        predicted = np.clip(predicted, 0, 1)
        # Ground truth goes first; data_range is explicit because the value
        # range of float images cannot be inferred reliably from the dtype.
        psnr_scores.append(peak_signal_noise_ratio(reference, predicted, data_range=1.0))
        # channel_axis alone selects multichannel mode; the deprecated
        # `multichannel` flag is no longer passed.
        ssim_scores.append(structural_similarity(
            (reference * 255).astype(np.uint8),
            (predicted * 255).astype(np.uint8),
            channel_axis=0,
        ))

if psnr_scores:
    print("PSNR :", np.mean(psnr_scores))
    print("SSIM :", np.mean(ssim_scores))
else:
    print("No crops available for PSNR / SSIM")