In [None]:
import torch
import argparse
import os
import numpy as np
import yaml
import random
from tqdm import tqdm
import torchvision
from infer import evaluate_map_v2, infer_from_model
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from dataset.st import SceneTextDataset
from torch.utils.data.dataloader import DataLoader
from detection.transform import GeneralizedRCNNTransform, resize_boxes
import detection
from detection.faster_rcnn import FastRCNNPredictor
from detection.anchor_utils import AnchorGenerator

# Device setup: prefer the GPU whenever torch can see one, else use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Collate function
def collate_function(data):
    return tuple(zip(*data))

def draw_and_save_boxes(image, boxes, labels, scores, output_dir, img_idx):
    """Draw labelled boxes on ``image`` and save it as ``img_<img_idx>_boxes.png``.

    ``image`` is permuted with ``(1, 2, 0)`` before display, so it is presumably
    a channels-first tensor (confirm). Each box unpacks into four corner
    coordinates ``(x1, y1, x2, y2)``; each box is paired positionally with one
    label and one score, and the score is rendered with two decimals.
    """
    plt.figure(figsize=(5, 5))
    plt.imshow(image.permute(1, 2, 0))
    axes = plt.gca()
    for (left, top, right, bottom), label_value, score_value in zip(boxes, labels, scores):
        width = right - left
        height = bottom - top
        outline = plt.Rectangle((left, top), width, height, fill=False, edgecolor='red', linewidth=2)
        axes.add_patch(outline)
        axes.text(left, top, f'{label_value} {score_value:.2f}', fontsize=8, color='red')
    plt.axis('off')
    target_path = os.path.join(output_dir, f'img_{img_idx}_boxes.png')
    plt.savefig(target_path)
    plt.close()

def save_objectness_heatmap(objectness_scores, base_dir=""):
    """Save one PNG heatmap per (level, image) pair of objectness scores.

    Args:
        objectness_scores: Nested sequence indexed as
            ``objectness_scores[img_idx][level]``. Every image must have at
            least as many levels as image 0 (the level count is taken from
            image 0). Each entry is indexed with ``[0]`` (first anchor channel)
            and the result goes to ``plt.imshow``, so it should be 2-D
            (presumably each entry is an (A, H, W) score array -- confirm
            against the caller).
        base_dir: Output root. Files are written to
            ``<base_dir>/level_<level>/img_<img_idx>_heatmap.png``. The default
            empty string writes relative to the current working directory.

    Returns:
        None. Nothing is written when ``objectness_scores`` is empty.
    """
    # len() instead of truthiness so numpy arrays are accepted as well.
    if len(objectness_scores) == 0:
        return

    num_imgs = len(objectness_scores)
    num_levels = len(objectness_scores[0])

    # os.makedirs('') raises FileNotFoundError, so only create a non-empty root.
    if base_dir:
        os.makedirs(base_dir, exist_ok=True)

    for level in range(num_levels):
        level_dir = os.path.join(base_dir, f"level_{level}")
        os.makedirs(level_dir, exist_ok=True)

        for img_idx in range(num_imgs):
            heatmap = objectness_scores[img_idx][level][0]  # Select first anchor channel

            fig = plt.figure(figsize=(5, 5))
            plt.imshow(heatmap, cmap='hot', interpolation='nearest')
            plt.title(f"Level {level} - Image {img_idx}")
            plt.axis("off")

            img_path = os.path.join(level_dir, f"img_{img_idx}_heatmap.png")
            print(f"Saving heatmap to {img_path}")
            plt.savefig(img_path)
            # Close this specific figure so one PNG per iteration never leaks figures.
            plt.close(fig)

def _to_training_targets(targets):
    """Return converted copies of ``targets`` for the model's training forward pass.

    Each input dict must provide ``'bboxes'``, ``'thetas'`` and ``'labels'``
    tensors. In the returned dicts ``'bboxes'`` is renamed to ``'boxes'``;
    boxes and thetas are cast to float, labels to long, and all three are moved
    to ``device``. Any other keys are carried over unchanged.

    The input dicts are not modified. The old in-place ``del target['bboxes']``
    would break any later reader expecting ``'bboxes'`` if the same dict object
    is handed out again (e.g. a dataset that caches its targets).
    """
    converted = []
    for target in targets:
        new_target = {key: value for key, value in target.items() if key != 'bboxes'}
        new_target['boxes'] = target['bboxes'].float().to(device)
        new_target['thetas'] = target['thetas'].float().to(device)
        new_target['labels'] = target['labels'].long().to(device)
        converted.append(new_target)
    return converted

def _write_ap_report(path, mean_ap, precisions, recalls):
    """Append precision/recall/mean-AP lines for 10 thresholds to ``path``.

    ``mean_ap``, ``precisions`` and ``recalls`` are indexed by threshold
    position, pairing entry ``k`` with the k-th value of
    ``np.linspace(0.05, 0.95, 10)`` (presumably the thresholds
    ``evaluate_map_v2`` used -- confirm). Each precision/recall entry is
    averaged with ``np.mean``.
    """
    with open(path, 'a') as f:
        for idx, threshold in enumerate(np.linspace(0.05, 0.95, 10)):
            avg_pre = np.mean(precisions[idx])
            avg_rec = np.mean(recalls[idx])
            f.write(f'threshold = {threshold}\n Precision = {avg_pre:.4f} | Recall = {avg_rec:.4f}\n mean AP = {mean_ap[idx]:.4f}\n')

def _save_epoch_heatmaps(model, data_loader, out_dir):
    """Save sigmoid RPN objectness heatmaps for every batch of ``data_loader``.

    Runs the model's own transform, backbone and RPN head in eval mode under
    ``torch.no_grad``, then restores the model's previous train/eval mode.
    NOTE: file names use the index within a batch, so later batches overwrite
    earlier ones (the 4-image test split with batch_size=4 is a single batch).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for ims, _, _ in data_loader:
                images = [im.float().to(device) for im in ims]
                # Use the model's transform so heatmaps match the resolution the model actually sees.
                image_list, _ = model.transform(images)
                features_list = list(model.backbone(image_list.tensors).values())
                rpn_logits, _ = model.rpn.head(features_list)
                # rpn.head yields one tensor per feature level, batch-first (torchvision
                # layout; confirm for this fork). Regroup image-major so that
                # save_objectness_heatmap's [img][level] indexing labels files correctly.
                per_level = [logit.sigmoid().cpu().numpy() for logit in rpn_logits]
                per_image = [[level_scores[b] for level_scores in per_level] for b in range(len(images))]
                save_objectness_heatmap(per_image, out_dir)
    finally:
        model.train(was_training)

def train(args, output_dir='output'):
    """Train the oriented Faster R-CNN on SceneTextDataset, evaluating every epoch.

    Reads ``args.config_path`` (YAML with ``dataset_params`` and
    ``train_params``), seeds the RNGs, and holds out 4 random images as a test
    split. For each epoch it first runs ``infer_from_model``/``evaluate_map_v2``
    on both splits, appends AP reports, and saves RPN objectness heatmaps --
    all for the model as it stands *before* that epoch's updates. It then runs
    one training pass and saves a checkpoint under ``output_dir``.

    Returns None. If the config is not valid YAML, prints the error and returns early.
    """
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            return

    print(config)
    os.makedirs(output_dir, exist_ok=True)

    dataset_config = config['dataset_params']
    train_config = config['train_params']

    seed = train_config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(seed)

    # Load dataset; hold out 4 random images for evaluation, train on the rest.
    st = SceneTextDataset('train', root_dir=dataset_config['root_dir'])
    train_indices = np.random.choice(len(st), len(st) - 4, replace=False)
    test_indices = np.setdiff1d(np.arange(len(st)), train_indices)

    train_data = torch.utils.data.Subset(st, train_indices)
    test_data = torch.utils.data.Subset(st, test_indices)

    train_dataset = DataLoader(train_data, batch_size=4, shuffle=False, num_workers=0, collate_fn=collate_function)
    test_dataset = DataLoader(test_data, batch_size=4, shuffle=False, num_workers=0, collate_fn=collate_function)

    # Model setup
    faster_rcnn_model = detection.fasterrcnn_resnet50_fpn(pretrained_backbone=True, min_size=600, max_size=1000, thresholded_thetas=True, num_thetas=10)
    faster_rcnn_model.roi_heads.box_predictor = FastRCNNPredictor(
        faster_rcnn_model.roi_heads.box_predictor.cls_score.in_features,
        num_classes=dataset_config['num_classes'],
        thresholded_thetas=True,
        num_thetas=10
    )

    faster_rcnn_model.to(device)
    os.makedirs(train_config['task_name'], exist_ok=True)

    # NOTE(review): lr is hard-coded and train_config['lr'] / 'lr_steps' are ignored -- confirm intended.
    optimizer = torch.optim.SGD(
        lr=1E-4, params=filter(lambda p: p.requires_grad, faster_rcnn_model.parameters()),
        weight_decay=5E-5, momentum=0.9
    )

    # Honour the configured epoch count (previously overridden by a hard-coded 50).
    num_epochs = train_config['num_epochs']
    for i in range(num_epochs):
        rpn_classification_losses = []
        rpn_localization_losses = []
        frcnn_classification_losses = []
        frcnn_localization_losses = []

        epoch_dir = os.path.join(output_dir, 'heatmap_frames', f'epoch_{i}')
        # NOTE(review): nothing in this function writes into output_dir/epoch_{i}.
        os.makedirs(os.path.join(output_dir, f'epoch_{i}'), exist_ok=True)
        os.makedirs(os.path.join(epoch_dir, 'test'), exist_ok=True)
        os.makedirs(os.path.join(epoch_dir, 'train'), exist_ok=True)

        # Evaluate the model as it is before this epoch's updates.
        infer_from_model(faster_rcnn_model, test_dataset, os.path.join(epoch_dir, 'test'))
        mean_ap_test, pre_test, rec_test = evaluate_map_v2(faster_rcnn_model, test_dataset)

        infer_from_model(faster_rcnn_model, train_dataset, os.path.join(epoch_dir, 'train'))
        mean_ap_train, pre_train, rec_train = evaluate_map_v2(faster_rcnn_model, train_dataset)

        # Save AP values (each report uses its own split's mean AP).
        outfile_test = os.path.join(output_dir, f'epoch_{i}.txt')
        outfile_train = os.path.join(output_dir, f'epoch_train_{i}.txt')
        print(pre_train, pre_test)
        _write_ap_report(outfile_test, mean_ap_test, pre_test, rec_test)
        _write_ap_report(outfile_train, mean_ap_train, pre_train, rec_train)

        print(f"Saved test AP to {outfile_test}")
        print(f"Saved train AP to {outfile_train}")

        _save_epoch_heatmaps(faster_rcnn_model, test_dataset, epoch_dir)

        # The evaluation helpers may have switched the model to eval mode; torchvision-style
        # detectors only return the loss dict used below in train mode.
        faster_rcnn_model.train()
        for ims, targets, _ in tqdm(train_dataset):
            optimizer.zero_grad()
            images = [im.float().to(device) for im in ims]
            batch_losses = faster_rcnn_model(images, _to_training_targets(targets))
            loss = sum(batch_losses.values())

            rpn_classification_losses.append(batch_losses['loss_objectness'].item())
            rpn_localization_losses.append(batch_losses['loss_rpn_box_reg'].item())
            frcnn_classification_losses.append(batch_losses['loss_classifier'].item())
            frcnn_localization_losses.append(batch_losses['loss_box_reg'].item())

            loss.backward()
            optimizer.step()

        print(f'Finished epoch {i}')
        torch.save(
            faster_rcnn_model.state_dict(),
            os.path.join(output_dir, f'tv_frcnn_r50fpn_{train_config["ckpt_name"]}')
        )

        loss_output = (
            f"RPN Classification Loss: {np.mean(rpn_classification_losses):.4f} | "
            f"RPN Localization Loss: {np.mean(rpn_localization_losses):.4f} | "
            f"FRCNN Classification Loss: {np.mean(frcnn_classification_losses):.4f} | "
            f"FRCNN Localization Loss: {np.mean(frcnn_localization_losses):.4f}"
        )
        print(loss_output)

    print('Done Training...')

class Args:
    """Stand-in for parsed CLI arguments: points ``train`` at the YAML config."""

    config_path = 'config/st.yaml'

# Start training only when executed directly (script or notebook kernel), not on import.
if __name__ == '__main__':
    args = Args()
    train(args, output_dir='output')

SceneTextDataset
{'dataset_params': {'root_dir': 'Q1', 'num_classes': 2}, 'model_params': {'im_channels': 3, 'aspect_ratios': [0.5, 1, 2], 'scales': [128, 256, 512], 'min_im_size': 600, 'max_im_size': 1000, 'backbone_out_channels': 512, 'fc_inner_dim': 1024, 'rpn_bg_threshold': 0.3, 'rpn_fg_threshold': 0.7, 'rpn_nms_threshold': 0.7, 'rpn_train_prenms_topk': 12000, 'rpn_test_prenms_topk': 6000, 'rpn_train_topk': 2000, 'rpn_test_topk': 300, 'rpn_batch_size': 256, 'rpn_pos_fraction': 0.5, 'roi_iou_threshold': 0.5, 'roi_low_bg_iou': 0.1, 'roi_pool_size': 7, 'roi_nms_threshold': 0.3, 'roi_topk_detections': 100, 'roi_score_threshold': 0.05, 'roi_batch_size': 128, 'roi_pos_fraction': 0.25}, 'train_params': {'task_name': 'st', 'seed': 1111, 'infer_seed': 1122, 'acc_steps': 1, 'num_epochs': 100, 'lr_steps': [12, 16], 'lr': 0.001, 'ckpt_name': 'faster_rcnn_st.pth'}}




thresholded_thetas in pred False
OrientedBoxCoder
(10.0, 10.0, 5.0, 5.0, 1.0)
thresholded_thetas in pred True


  0%|          | 0/250 [00:01<?, ?it/s]


ValueError: Expected input batch_size (4096) to match target batch_size (2048).