In [None]:
import torch
from PIL import Image
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import ResNet18_Weights
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import json
import torch.optim as optim
import os
from torch.nn.utils.rnn import pad_sequence
import random
from torch.utils.data import random_split
from tqdm import tqdm


In [None]:
# Report which accelerator is available so the reader knows where training will run.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA is not available; running on CPU.")

In [None]:
class ResnetBackbone(nn.Module):
    """ResNet-18 feature extractor returning the C2, C3, C4 and C5 feature maps.

    Args:
        backbone: Backbone architecture name. Only ``'resnet18'`` is implemented;
            any other value raises ``ValueError`` instead of being silently ignored.
        pretrained: If True, load ``ResNet18_Weights.DEFAULT``; if False, build the
            network with randomly initialised weights (no download).

    Attributes:
        feature_channels: Channel counts of (C2, C3, C4, C5) for ResNet-18.
    """

    def __init__(self, backbone='resnet18', pretrained=True):
        super(ResnetBackbone, self).__init__()
        if backbone != 'resnet18':
            raise ValueError(f"Unsupported backbone {backbone!r}; only 'resnet18' is implemented")

        weights = ResNet18_Weights.DEFAULT if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Output channels of layer1..layer4 of ResNet-18, i.e. C2, C3, C4, C5.
        self.feature_channels = [64, 128, 256, 512]

        # Drop the last two children (the pooling/classifier tail in torchvision's
        # ResNet); only the convolutional trunk is needed for feature extraction.
        # Kept as `self.backbone` so state_dict keys stay the same as before.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Stem (no residual blocks): conv1 -> batch norm -> ReLU -> max pooling.
        self.stem = nn.Sequential(
            self.backbone[0],  # conv1
            self.backbone[1],  # batch norm
            self.backbone[2],  # ReLU
            self.backbone[3],  # max pooling
        )

        # Residual stages. These are the same module objects held in self.backbone,
        # so their parameters are shared, not duplicated.
        self.layer1 = self.backbone[4]
        self.layer2 = self.backbone[5]
        self.layer3 = self.backbone[6]
        self.layer4 = self.backbone[7]

    def forward(self, x):
        """Run the bottom-up pathway.

        Args:
            x: Image batch; presumably (N, 3, H, W) as ResNet expects — TODO confirm.

        Returns:
            Tuple ``(c2, c3, c4, c5)``: outputs of layer1..layer4, finest first.
        """
        c1 = self.stem(x)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        return c2, c3, c4, c5

In [None]:
# Sanity check: build the backbone once, run a dummy 256x256 image through it and
# confirm each returned map has the channel count declared in `feature_channels`.
temp = ResnetBackbone()
temp.eval()
with torch.no_grad():
    _feature_maps = temp(torch.zeros(1, 3, 256, 256))
for _level, _fmap, _channels in zip(("C2", "C3", "C4", "C5"), _feature_maps, temp.feature_channels):
    assert _fmap.shape[1] == _channels, f"{_level}: expected {_channels} channels, got {_fmap.shape[1]}"
    print(_level, tuple(_fmap.shape))

In [None]:
class FPN(nn.Module):
    """Feature Pyramid Network neck on top of a four-level bottom-up backbone.

    The backbone must expose ``feature_channels`` (channel counts of C2..C5, in that
    order) and its forward pass must return ``(c2, c3, c4, c5)``.

    Args:
        backbone: Bottom-up feature extractor (e.g. ``ResnetBackbone``).
        out_channel: Channel count shared by every pyramid output P2..P5.
    """

    def __init__(self, backbone, out_channel=256):
        super(FPN, self).__init__()
        self.backbone = backbone

        # 1x1 lateral projections C_i -> M_i, stored in top-down order:
        # index 0 projects C5, index 3 projects C2.
        self.lateral_layers = nn.ModuleList([
            nn.Conv2d(self.backbone.feature_channels[level], out_channel, kernel_size=1)
            for level in (3, 2, 1, 0)
        ])

        # 3x3 smoothing convs applied to M4, M3, M2 (in that order) to give P4, P3, P2.
        self.smoothing_layers = nn.ModuleList([
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
            for _ in range(3)
        ])

    def forward(self, x):
        """Build the pyramid for an input batch.

        Returns:
            list: ``[p2, p3, p4, p5]``, finest level first. P5 is the lateral
            projection of C5 without smoothing; P4..P2 go through a 3x3 conv.
        """
        c2, c3, c4, c5 = self.backbone(x)  # bottom-up pathway

        # Top-down pathway: nearest-neighbour upsample the running map to the exact
        # spatial size of the next finer level and add that level's lateral projection.
        running = self.lateral_layers[0](c5)
        merged = [running]  # grows into [m5, m4, m3, m2]
        for lateral, skip in zip(self.lateral_layers[1:], (c4, c3, c2)):
            running = lateral(skip) + F.interpolate(running, size=skip.shape[2:], mode='nearest')
            merged.append(running)

        p4, p3, p2 = (smooth(m) for smooth, m in zip(self.smoothing_layers, merged[1:]))
        p5 = merged[0]
        return [p2, p3, p4, p5]

In [None]:
class DetectionHead(nn.Module):
    """Prediction head applied to every FPN level with shared weights.

    The same conv tower and output convs are run on each feature map, so one set of
    parameters serves all pyramid levels.

    Args:
        in_channels: Channel count of each incoming feature map.
        type_classes: Number of text-type logits predicted per location.
        legibility_classes: Number of legibility logits predicted per location.
        hidden_channels: Width of the shared 3x3 conv tower (defaults to 256, the
            previously hard-coded value).
    """

    def __init__(self, in_channels=256, type_classes=3, legibility_classes=2, hidden_channels=256):
        super(DetectionHead, self).__init__()

        # Shared tower: two 3x3 conv + ReLU layers; stride 1 / padding 1 keep H and W.
        self.shared_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        # 1x1 output convs. No activation is applied: type/legibility outputs are raw
        # logits, and the 4 bbox channels are raw values whose encoding is not yet
        # defined (see anchor_free_assign_targets) — TODO confirm.
        self.type_head = nn.Conv2d(hidden_channels, type_classes, kernel_size=1)
        self.bbox_head = nn.Conv2d(hidden_channels, 4, kernel_size=1)
        self.legibility_head = nn.Conv2d(hidden_channels, legibility_classes, kernel_size=1)

    def forward(self, feature_maps):
        """Apply the head to each feature map.

        Args:
            feature_maps: Iterable of feature maps (e.g. FPN's ``[p2, p3, p4, p5]``).

        Returns:
            dict with keys ``"type"``, ``"bbox"`` and ``"legibility"``; each value is a
            list holding one output tensor per input feature map, in input order.
        """
        predictions = {"type": [], "bbox": [], "legibility": []}
        for feature_map in feature_maps:
            x = self.shared_conv(feature_map)
            predictions["type"].append(self.type_head(x))
            predictions["bbox"].append(self.bbox_head(x))
            predictions["legibility"].append(self.legibility_head(x))
        return predictions



In [None]:
class TextDetectionModel(nn.Module):
    """Text detector: ResNet-18 backbone -> FPN neck -> shared detection head.

    Args:
        type_classes: Number of text-type logits per location.
            NOTE(review): COCODataset emits type labels 0..3 (unknown, machine
            printed, handwritten, others), i.e. 4 classes, while this default is 2.
            Pass ``type_classes=4`` when training on that dataset.
        legibility_classes: Number of legibility logits per location.
        pretrained: Forwarded to ResnetBackbone; False builds the backbone with
            random weights.
    """

    def __init__(self, type_classes=2, legibility_classes=2, pretrained=True):
        super(TextDetectionModel, self).__init__()

        self.backbone = ResnetBackbone(pretrained=pretrained)
        self.FPN = FPN(self.backbone)
        # FPN's default out_channel is 256, so the head consumes 256-channel maps.
        self.detection_head = DetectionHead(
            in_channels=256,
            type_classes=type_classes,
            legibility_classes=legibility_classes
        )

    def forward(self, x):
        """Return the head's prediction dict (``"type"``, ``"bbox"``, ``"legibility"``),
        each a list with one tensor per pyramid level P2..P5."""
        fpn_features = self.FPN(x)
        return self.detection_head(fpn_features)
        

In [None]:
class COCODataset(Dataset):
    """COCO-Text style dataset yielding images and per-box detection targets.

    The annotation JSON must contain an ``'imgs'`` mapping (image-id string ->
    metadata with ``'file_name'``) and an ``'anns'`` mapping whose values carry an
    ``'image_id'`` matching ``int(<imgs key>)`` plus optional ``'bbox'`` ([x, y, w, h]),
    ``'class'`` and ``'legibility'`` fields.

    Args:
        annotation_file: Path to the annotation JSON (read as UTF-8).
        image_dir: Directory containing the image files.
        transform: Optional callable applied to the PIL image.
        target_size: ``(height, width)`` (or a single int for a square) that boxes are
            rescaled to. Keep it equal to the output size of ``transform``; the default
            matches ``transforms.Resize((256, 256))``.
    """

    # Text-type label encoding; anything else (including a missing class) maps to 0.
    _TYPE_LABELS = {"machine printed": 1, "handwritten": 2, "others": 3}

    def __init__(self, annotation_file, image_dir, transform=None, target_size=(256, 256)):
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(annotation_file, 'r', encoding='utf-8') as ann_file:
            self.data = json.load(ann_file)

        self.image_dir = image_dir
        self.transform = transform
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        self.target_size = tuple(target_size)

        self.image_ids = list(self.data['imgs'].keys())

        # Group annotations by image once (preserving their original order) so that
        # __getitem__ is a dict lookup rather than a scan over every annotation.
        self._anns_by_image = {}
        for ann in self.data['anns'].values():
            self._anns_by_image.setdefault(ann['image_id'], []).append(ann)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the idx-th image.

        ``target`` holds ``'bboxes'`` (float32, shape (N, 4), [x1, y1, x2, y2] in
        target_size pixels), ``'type_labels'`` and ``'legibility_labels'`` (int64,
        shape (N,)), and ``'image_id'`` (the string key from ``'imgs'``).
        """
        img_id = self.image_ids[idx]
        img_metadata = self.data['imgs'][img_id]
        annotations = self._anns_by_image.get(int(img_id), [])

        img_path = os.path.join(self.image_dir, img_metadata['file_name'])
        # The context manager closes the file handle once the pixels are converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Boxes are in original-image pixels; compute scale factors to target_size.
        width, height = image.size
        target_h, target_w = self.target_size
        scale_w = target_w / width
        scale_h = target_h / height

        if self.transform:
            image = self.transform(image)

        bboxes = []
        type_labels = []
        legibility_labels = []

        for ann in annotations:
            bbox = ann.get('bbox', [0, 0, 0, 0])
            if len(bbox) == 4:
                # [x, y, w, h] -> scaled corner form [x1, y1, x2, y2]
                x1, y1, w, h = bbox
                bboxes.append([
                    x1 * scale_w,
                    y1 * scale_h,
                    (x1 + w) * scale_w,
                    (y1 + h) * scale_h
                ])
            else:
                bboxes.append([0, 0, 0, 0])  # malformed box: placeholder keeps labels aligned

            type_labels.append(self._TYPE_LABELS.get(ann.get('class', "unknown"), 0))
            # 1 = legible, 0 = illegible (also used when the field is missing).
            legibility_labels.append(1 if ann.get('legibility', "illegible") == "legible" else 0)

        return image, {
            # reshape keeps an image without annotations at shape (0, 4), not (0,).
            'bboxes': torch.tensor(bboxes, dtype=torch.float32).reshape(-1, 4),
            'type_labels': torch.tensor(type_labels, dtype=torch.int64),
            'legibility_labels': torch.tensor(legibility_labels, dtype=torch.int64),
            'image_id': img_id
        }

In [None]:
def collate_fn(batch):
    """Collate ``(image, target)`` pairs into one padded batch without mutating inputs.

    Images are stacked, so they must share a shape (e.g. via a Resize transform).
    Targets have varying box counts, so they are right-padded with ``-1`` to the
    longest sample. A sample without annotations contributes one all-zero box with
    type and legibility label 0, because an empty target can arrive as a 1-D tensor
    of shape (0,) whose trailing dims would not match the (N, 4) boxes of others.
    NOTE(review): label 0 is also COCODataset's "unknown" type / "illegible" code, so
    this placeholder is indistinguishable from a real annotation.

    Returns:
        dict with ``'images'`` (stacked), ``'bboxes'`` (B, max_boxes, 4),
        ``'type_labels'`` and ``'legibility_labels'`` (B, max_boxes) and
        ``'image_ids'`` (list, in batch order).
    """
    images = []
    bboxes = []
    type_labels = []
    legibility_labels = []
    image_ids = []

    for image, annotation in batch:
        images.append(image)

        if len(annotation['bboxes']) == 0:
            # Fresh placeholder tensors; the caller's dict is left untouched.
            bboxes.append(torch.tensor([[0, 0, 0, 0]], dtype=torch.float32))
            type_labels.append(torch.tensor([0], dtype=torch.int64))
            legibility_labels.append(torch.tensor([0], dtype=torch.int64))
        else:
            bboxes.append(annotation['bboxes'])
            type_labels.append(annotation['type_labels'])
            legibility_labels.append(annotation['legibility_labels'])
        image_ids.append(annotation['image_id'])

    return {
        'images': torch.stack(images),
        'bboxes': pad_sequence(bboxes, batch_first=True, padding_value=-1),
        'type_labels': pad_sequence(type_labels, batch_first=True, padding_value=-1),
        'legibility_labels': pad_sequence(legibility_labels, batch_first=True, padding_value=-1),
        'image_ids': image_ids
    }


In [None]:
# Images are resized to 256x256; COCODataset rescales its boxes to 256x256 by
# default, so keep these two sizes in sync.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

annotation_file = "./COCO-Text/COCO_Text.json"
image_dir = "./COCO-Text/train2014"

BATCH_SIZE = 8
N_PREVIEW = 20  # number of training samples printed as a sanity check

dataset = COCODataset(annotation_file=annotation_file, image_dir=image_dir, transform=transform)
dataset_size = len(dataset)

# random_split resolves the fractional lengths itself; the fixed seed makes the
# split identical on every run.
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
train_size = len(train_dataset)
val_size = len(val_dataset)
print(dataset_size)
print(train_size)
print(val_size)

for i in range(min(N_PREVIEW, train_size)):
    image, annotations = train_dataset[i]

    print(f"Example {i + 1}")
    print("Image shape:", image.shape)
    print("Bounding Boxes:", annotations['bboxes'])
    print("Type Labels:", annotations['type_labels'])
    print("Legibility Labels:", annotations['legibility_labels'])
    print("Image ID:", annotations['image_id'])
    print("-" * 50)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
# Validation order is fixed so evaluation and inspection are reproducible.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
# Sanity-check one validation batch: tensor shapes first, then the targets of one
# randomly chosen sample in it.
batch = next(iter(val_dataloader))
# Draw from the actual batch length: a batch holds fewer than batch_size samples
# when the dataset is small.
index = random.randrange(len(batch['image_ids']))

print(index)
print("Batch 1")
print("Images Shape:", batch['images'].shape)
print("Bboxes Shape:", batch['bboxes'].shape)
print("Type Labels Shape:", batch['type_labels'].shape)
print("Legibility Labels Shape:", batch['legibility_labels'].shape)
print("Image IDs:", batch['image_ids'])
print("-" * 50)

print(f"Details of entry {index} in the batch:")
print("Image shape:", batch['images'][index].shape)
print("Bounding Boxes:", batch['bboxes'][index])
print("Type Labels:", batch['type_labels'][index])
print("Legibility Labels:", batch['legibility_labels'][index])
print("Image ID:", batch['image_ids'][index])
print("=" * 50)


In [None]:
def anchor_free_assign_targets(predictions, bboxes, type_labels, legibility_labels, image_sizes, strides=(4, 8, 16, 32)):
    """Assign ground-truth targets to FPN feature-map locations (anchor-free).

    Not implemented yet. The intended contract is a reviewer assumption to confirm:
    ``predictions`` is the DetectionHead output dict, the label/box arguments are
    collate_fn's padded tensors (padding value -1), and ``strides`` gives the
    downsampling factor of each pyramid level P2..P5.

    Raises:
        NotImplementedError: always, so callers fail at the call site instead of
            silently receiving ``None``.
    """
    # TODO: implement target assignment.
    raise NotImplementedError("anchor_free_assign_targets is not implemented yet")


In [None]:
LEARNING_RATE = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# COCODataset type labels span 0..3 (unknown, machine printed, handwritten, others),
# so the type head needs 4 outputs; legibility labels are 0/1.
model = TextDetectionModel(type_classes=4, legibility_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# collate_fn pads targets with -1; the classification losses must skip those slots
# (CrossEntropyLoss ignores -100 by default, and -1 is an invalid class index).
bbox_loss_fn = nn.MSELoss()  # placeholder; padded (-1) boxes must be masked out before use
type_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
legibility_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
