In [1]:

import argparse
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import accuracy_score
from models.TMC import TMC_base_channel
import torchvision.transforms as transforms
from data.aligned_conc_dataset import AlignedConcDataset
from utils.utils import *
from utils.logger import create_logger
import os
from torch.utils.data import DataLoader
import numpy as np
import torch

In [2]:



def get_args(parser):
    """Register all data, model and optimisation options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) that receives the option definitions.

    Returns:
        The same ``parser`` object, so the call can be chained.

    Note:
        ``--include_bn`` is declared with ``type=int`` but defaults to the
        boolean ``True``; a value given on the command line is parsed as int.
    """
    # Built inside the function so every call gets fresh list defaults.
    option_table = (
        ("--batch_sz", dict(type=int, default=32)),
        ("--data_path", dict(type=str, default="/home/hzhaobi/Multired/nyud2")),
        ("--LOAD_SIZE", dict(type=int, default=256)),
        ("--FINE_SIZE", dict(type=int, default=224)),
        ("--dropout", dict(type=float, default=0.1)),
        ("--gradient_accumulation_steps", dict(type=int, default=3)),
        ("--hidden", dict(nargs="*", type=int, default=[])),
        ("--channel_hidden", dict(nargs="*", type=int, default=[512])),
        ("--channel_size", dict(type=int, default=256)),
        ("--channel_snr", dict(type=int, default=-100)),
        ("--hidden_sz", dict(type=int, default=768)),
        ("--img_embed_pool_type", dict(type=str, default="avg", choices=["max", "avg"])),
        ("--img_hidden_sz", dict(type=int, default=512)),
        ("--include_bn", dict(type=int, default=True)),
        ("--lr", dict(type=float, default=1e-4)),
        ("--lr_factor", dict(type=float, default=0.3)),
        ("--lr_patience", dict(type=int, default=10)),
        ("--max_epochs", dict(type=int, default=500)),
        ("--n_workers", dict(type=int, default=12)),
        ("--name", dict(type=str, default="TMCBase_channel")),
        ("--num_image_embeds", dict(type=int, default=1)),
        ("--patience", dict(type=int, default=20)),
        ("--savedir", dict(type=str, default="./savepath/TMCBase_channel/nyud/")),
        ("--seed", dict(type=int, default=1)),
        ("--n_classes", dict(type=int, default=10)),
    )
    # Registration order is kept so the generated --help text is unchanged.
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser


def get_optimizer(model, args):
    """Create the Adam optimizer used for training.

    Args:
        model: Module whose ``parameters()`` are optimised.
        args: Namespace providing the learning rate as ``args.lr``.

    Returns:
        An ``optim.Adam`` instance with a fixed weight decay of ``1e-5``.
    """
    return optim.Adam(
        params=model.parameters(),
        lr=args.lr,
        weight_decay=1e-5,
    )


def get_scheduler(optimizer, args):
    """Create a ``ReduceLROnPlateau`` scheduler that tracks a metric to maximise.

    Args:
        optimizer: The optimizer whose learning rate is reduced.
        args: Namespace providing ``lr_patience`` (epochs without improvement
            before a reduction) and ``lr_factor`` (multiplicative decay).

    Returns:
        An ``optim.lr_scheduler.ReduceLROnPlateau`` in ``"max"`` mode.
    """
    import inspect  # local import: only needed for the one-off signature probe

    plateau_cls = optim.lr_scheduler.ReduceLROnPlateau
    scheduler_kwargs = dict(
        mode="max", patience=args.lr_patience, factor=args.lr_factor
    )
    # `verbose` is deprecated in recent PyTorch releases and, it appears, no
    # longer accepted by the newest ones (NOTE(review): confirm the exact
    # version). Pass it only when the installed constructor still declares it,
    # so older installs keep their LR-reduction messages and newer ones do not
    # fail with a TypeError.
    if "verbose" in inspect.signature(plateau_cls.__init__).parameters:
        scheduler_kwargs["verbose"] = True
    return plateau_cls(optimizer, **scheduler_kwargs)


def model_forward(model, args, criterion, batch):
    """Run one batch through the model and compute the summed loss.

    The batch tensors are moved to the device the model's parameters live on,
    so the function works for a CUDA model (as before) and also for a model
    kept on the CPU or on a non-default GPU.

    Args:
        model: Callable module invoked as ``model(rgb, depth)`` and returning
            ``(depth_logits, rgb_logits, depth_rgb_logits)``.
        args: Run configuration; not used here, kept for a uniform signature.
        criterion: Loss callable applied as ``criterion(logits, target)``.
        batch: Mapping with the tensors ``'A'`` (passed to the model as the
            RGB input), ``'B'`` (passed as the depth input) and ``'label'``.

    Returns:
        Tuple ``(loss, depth_logits, rgb_logits, depth_rgb_logits, tgt)`` where
        ``loss`` is the sum of the three per-head losses and ``tgt`` is the
        label tensor on the model's device.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-less model: keep the historical "use the GPU" behaviour
        # whenever a GPU exists.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    rgb = batch['A'].to(device)
    depth = batch['B'].to(device)
    tgt = batch['label'].to(device)

    depth_logits, rgb_logits, depth_rgb_logits = model(rgb, depth)

    # Each head (depth only, RGB only, fused) is supervised with the same labels.
    loss = criterion(depth_logits, tgt) + \
           criterion(rgb_logits, tgt) + \
           criterion(depth_rgb_logits, tgt)
    return loss, depth_logits, rgb_logits, depth_rgb_logits, tgt


@torch.no_grad()
def model_eval(data, model, args, criterion):
    """Evaluate ``model`` on every batch of ``data`` without tracking gradients.

    Args:
        data: Iterable of batches accepted by ``model_forward``.
        model: Module to evaluate; it is switched to eval mode here and is
            left in eval mode on return.
        args: Run configuration, forwarded to ``model_forward``.
        criterion: Loss callable, forwarded to ``model_forward``.

    Returns:
        Dict with ``"loss"`` (mean of the per-batch losses) and the accuracies
        ``"depth_acc"``, ``"rgb_acc"`` and ``"depthrgb_acc"`` of the three heads.
    """
    model.eval()

    batch_losses = []
    labels, depth_hat, rgb_hat, fused_hat = [], [], [], []

    for batch in data:
        loss, depth_logits, rgb_logits, depth_rgb_logits, tgt = model_forward(
            model, args, criterion, batch
        )
        batch_losses.append(loss.item())

        # extend() appends the per-sample entries of each batch array, so the
        # lists end up flat without a separate flattening pass.
        for store, logits in (
            (depth_hat, depth_logits),
            (rgb_hat, rgb_logits),
            (fused_hat, depth_rgb_logits),
        ):
            store.extend(logits.argmax(dim=1).cpu().numpy())
        labels.extend(tgt.cpu().numpy())

    return {
        "loss": np.mean(batch_losses),
        "depth_acc": accuracy_score(labels, depth_hat),
        "rgb_acc": accuracy_score(labels, rgb_hat),
        "depthrgb_acc": accuracy_score(labels, fused_hat),
    }


In [3]:


def train(args):
    """Train ``TMC_base_channel`` with early stopping and report final metrics.

    Steps: seed, build the train/test loaders, build model / optimizer /
    scheduler, resume from ``checkpoint.pt`` in the save directory when it
    exists, train with gradient accumulation, evaluate on the test loader after
    every epoch, checkpoint, stop after ``args.patience`` epochs without
    improvement of the fused accuracy, then reload ``model_best.pt`` and log
    its metrics.

    Args:
        args: Parsed configuration (see ``get_args``). ``args.savedir`` is
            rewritten in place to ``<savedir>/<name>``.

    Note:
        The same test loader is used for per-epoch model selection and for the
        final report, so the reported test numbers are not from held-out data.
    """
    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    # Per-channel normalisation constants (NOTE(review): presumably dataset
    # statistics — confirm their origin).
    mean = [0.4951, 0.3601, 0.4587]
    std = [0.1474, 0.1950, 0.1646]
    train_transforms = [
        transforms.Resize((args.LOAD_SIZE, args.LOAD_SIZE)),
        transforms.RandomCrop((args.FINE_SIZE, args.FINE_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    # Evaluation uses a deterministic resize only (no crop / flip).
    val_transforms = [
        transforms.Resize((args.FINE_SIZE, args.FINE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]

    train_loader = DataLoader(
        AlignedConcDataset(args, data_dir=os.path.join(args.data_path, 'train'), transform=transforms.Compose(train_transforms)),
        batch_size=args.batch_sz,
        shuffle=True,
        num_workers=args.n_workers,
    )
    test_loader = DataLoader(
        AlignedConcDataset(args, data_dir=os.path.join(args.data_path, 'test'), transform=transforms.Compose(val_transforms)),
        batch_size=args.batch_sz,
        shuffle=False,
        num_workers=args.n_workers,
    )

    model = TMC_base_channel(args).cuda()
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)
    criterion = nn.CrossEntropyLoss()

    logger = create_logger(f"{args.savedir}/logfile.log", args)

    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

    # Resume an interrupted run: restore counters and all stateful objects.
    if os.path.exists(os.path.join(args.savedir, "checkpoint.pt")):
        checkpoint = torch.load(os.path.join(args.savedir, "checkpoint.pt"))
        start_epoch = checkpoint["epoch"]
        n_no_improve = checkpoint["n_no_improve"]
        best_metric = checkpoint["best_metric"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    # Guard against a zero / negative setting, which would otherwise divide by zero.
    accum_steps = max(1, args.gradient_accumulation_steps)

    for i_epoch in range(start_epoch, args.max_epochs):
        model.train()
        optimizer.zero_grad()
        train_losses = []
        pending = 0  # batches back-propagated since the last optimizer step
        for batch in tqdm(train_loader, total=len(train_loader)):
            loss, *_ = model_forward(model, args, criterion, batch)
            # Record the loss before scaling so the logged value is comparable
            # with the evaluation loss.
            train_losses.append(loss.item())
            if accum_steps > 1:
                loss = loss / accum_steps

            loss.backward()
            global_step += 1
            pending += 1
            if pending == accum_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0

        # Apply gradients of a trailing incomplete accumulation group. Without
        # this they would be wiped by zero_grad() at the start of the next
        # epoch. They keep the 1/accum_steps scaling, so this update is
        # somewhat smaller than a full one.
        if pending:
            optimizer.step()
            optimizer.zero_grad()

        metrics = model_eval(test_loader, model, args, criterion)
        logger.info(f"Train Loss: {np.mean(train_losses):.4f}")
        log_metrics("val", metrics, logger)
        logger.info(
            f"val: Loss: {metrics['loss']:.5f} | depth_acc: {metrics['depth_acc']:.5f}, "
            f"rgb_acc: {metrics['rgb_acc']:.5f}, depth rgb acc: {metrics['depthrgb_acc']:.5f}"
        )

        # The fused-head accuracy drives LR scheduling, checkpoint selection
        # and early stopping.
        tuning_metric = metrics["depthrgb_acc"]
        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric
        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
        else:
            n_no_improve += 1

        save_checkpoint(
            {
                "epoch": i_epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "n_no_improve": n_no_improve,
                "best_metric": best_metric,
            },
            is_improvement,
            args.savedir,
        )

        if n_no_improve >= args.patience:
            logger.info("No improvement. Breaking out of loop.")
            break

    # Final report uses the best checkpoint, not the last epoch's weights.
    load_checkpoint(model, os.path.join(args.savedir, "model_best.pt"))
    test_metrics = model_eval(test_loader, model, args, criterion)
    logger.info(
        f"Test: Loss: {test_metrics['loss']:.5f} | depth_acc: {test_metrics['depth_acc']:.5f}, "
        f"rgb_acc: {test_metrics['rgb_acc']:.5f}, depth rgb acc: {test_metrics['depthrgb_acc']:.5f}"
    )
    log_metrics("Test", test_metrics, logger)



In [4]:
# Build the option parser and materialise the default configuration. An empty
# argument list is parsed explicitly so the notebook kernel's own command-line
# flags are never interpreted as training options.
parser = get_args(argparse.ArgumentParser(description="Train TMC Base Model"))
args, remaining_args = parser.parse_known_args([])
assert not remaining_args, remaining_args

In [5]:
# Interactive copy of the setup phase of train(): seed, data loaders, model,
# optimizer, scheduler, logger and optional checkpoint resume.
set_seed(args.seed)

# Append the run name only once. Without the guard, every re-execution of this
# cell nests another "<name>/" level and moves the checkpoint / log directory.
# (train(args) performs this join itself, unguarded.)
if os.path.basename(os.path.normpath(args.savedir)) != args.name:
    args.savedir = os.path.join(args.savedir, args.name)
os.makedirs(args.savedir, exist_ok=True)

# Per-channel normalisation constants (NOTE(review): presumably dataset
# statistics — confirm their origin).
mean = [0.4951, 0.3601, 0.4587]
std = [0.1474, 0.1950, 0.1646]
train_transforms = [
    transforms.Resize((args.LOAD_SIZE, args.LOAD_SIZE)),
    transforms.RandomCrop((args.FINE_SIZE, args.FINE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
# Evaluation uses a deterministic resize only (no crop / flip).
val_transforms = [
    transforms.Resize((args.FINE_SIZE, args.FINE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]

train_loader = DataLoader(
    AlignedConcDataset(args, data_dir=os.path.join(args.data_path, 'train'), transform=transforms.Compose(train_transforms)),
    batch_size=args.batch_sz,
    shuffle=True,
    num_workers=args.n_workers,
)
test_loader = DataLoader(
    AlignedConcDataset(args, data_dir=os.path.join(args.data_path, 'test'), transform=transforms.Compose(val_transforms)),
    batch_size=args.batch_sz,
    shuffle=False,
    num_workers=args.n_workers,
)

model = TMC_base_channel(args).cuda()
optimizer = get_optimizer(model, args)
scheduler = get_scheduler(optimizer, args)
criterion = nn.CrossEntropyLoss()

logger = create_logger(f"{args.savedir}/logfile.log", args)

start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

# Resume an interrupted run: restore counters and all stateful objects.
if os.path.exists(os.path.join(args.savedir, "checkpoint.pt")):
    checkpoint = torch.load(os.path.join(args.savedir, "checkpoint.pt"))
    start_epoch = checkpoint["epoch"]
    n_no_improve = checkpoint["n_no_improve"]
    best_metric = checkpoint["best_metric"]
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])




INFO - 08/08/25 16:11:45 - 0:00:00 - FINE_SIZE: 224
                                     LOAD_SIZE: 256
                                     batch_sz: 32
                                     channel_hidden: [512]
                                     channel_size: 256
                                     channel_snr: -100
                                     data_path: /home/hzhaobi/Multired/nyud2
                                     dropout: 0.1
                                     gradient_accumulation_steps: 3
                                     hidden: []
                                     hidden_sz: 768
                                     img_embed_pool_type: avg
                                     img_hidden_sz: 512
                                     include_bn: True
                                     lr: 0.0001
                                     lr_factor: 0.3
                                     lr_patience: 10
                                     max_epochs: 500
    

In [6]:
# Smoke test: switch the model to training mode and push only the first
# training batch through model_forward. The outputs and `batch` stay bound as
# notebook globals.
model.train()
optimizer.zero_grad()
train_losses = []  # initialised here but never appended to in this cell
for batch in tqdm(train_loader, total=len(train_loader)):
    loss, depth_logits, rgb_logits, depth_rgb_logits, tgt = model_forward(model, args, criterion, batch)
    break  # one batch is enough; no backward pass or optimizer step is taken

  0%|          | 0/25 [00:02<?, ?it/s]


NotImplementedError: Module [ModuleList] is missing the required "forward" function

In [None]:
# Manual replay of model_forward's input handling on the batch fetched above:
# 'A' is fed to the model as the RGB input, 'B' as the depth input.
rgb = batch['A'].cuda()
depth = batch['B'].cuda()
tgt = batch['label'].cuda()

# Full forward pass; returns the depth-only, RGB-only and fused logits.
depth_logits, rgb_logits, depth_rgb_logits = model(rgb, depth)

In [7]:
# Step through the model's sub-modules by hand to localise a failure in the
# full forward pass. Requires `rgb` and `depth` from the preceding cell.

# Take the channel SNR from the parsed configuration rather than repeating the
# literal, so this walkthrough stays in sync with the run's settings.
channel_snr = args.channel_snr

# Depth branch: encoder -> channel encoder -> channel -> flatten per sample.
depth_feat = model.depthenc(depth)
depth_feat = model.depthchannel_enc(depth_feat)
depth_feat = model.channel(depth_feat, channel_snr)
depth_feat = torch.flatten(depth_feat, start_dim=1)

# RGB branch: same pipeline with the RGB-specific encoders.
rgb_feat = model.rgbenc(rgb)
rgb_feat = model.rgbchannel_enc(rgb_feat)
rgb_feat = model.channel(rgb_feat, channel_snr)
rgb_feat = torch.flatten(rgb_feat, start_dim=1)

# individual modality logits
depth_logits = model._forward_mlp(depth_feat, model.clf_depth)
rgb_logits = model._forward_mlp(rgb_feat, model.clf_rgb)

NameError: name 'depth' is not defined