In [None]:
import os

# Expose only physical GPU 1 to this process. CUDA numbers the visible devices
# from 0, so this GPU appears as device 0 (the `devices=[0]` used by the Trainer).
# The variable must be set before CUDA is initialised in the kernel; changing it
# afterwards has no effect until the kernel is restarted.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})


In [None]:
import pytorch_lightning as pl
import torch

In [None]:
from src.models.letr import Letr
from LETR.clean_model.backbone import build_backbone
from LETR.clean_model.transformer.transformer import build_transformer 
from LETR.clean_model.losses.losses import SetCriterion
from LETR.clean_model.utils.matcher import build_matcher
from LETR.data import build_dataset


### Configuration and model components (backbone, transformer, matcher, criterion)


In [None]:
import argparse
from args import get_args_parser

# Build the CLI parser on top of the project's shared LETR argument definitions.
# The title is passed as `description`: given positionally, argparse would treat it
# as `prog`, the program name printed at the start of usage lines.
parser = argparse.ArgumentParser(
    description="LETR training and evaluation script",
    parents=[get_args_parser()],
    allow_abbrev=False,
)
# parse_known_args (rather than parse_args) ignores argv entries the parser does not
# recognise, e.g. the connection-file flag the Jupyter kernel process is started with.
args, _ = parser.parse_known_args()

In [None]:
# Notebook-level overrides of the parsed defaults.
args.output_dir = "test_lightining"
args.batch_size = 1
args.coco_path = "data/wireframe_processed"
args.num_workers = 16
args.lr = 1e-4
args.dropout = 0  # presumably read by the model builders to disable dropout — confirm
args.lr_drop = 200

# Seed Python, NumPy and torch (and, with workers=True, DataLoader workers) before any
# weights are initialised or data is sampled, so Restart & Run All is reproducible.
# NOTE(review): falls back to 42 if the project parser defines no `seed` option — confirm.
pl.seed_everything(getattr(args, "seed", 42), workers=True)

In [None]:
# Number of classes handed to both SetCriterion and Letr below.
num_classes: int = 1


In [None]:
# Assemble the LETR building blocks from the parsed arguments. The original order
# (backbone, transformer, matcher) is kept, since building modules may draw from the
# global RNG.
backbone = build_backbone(args)
transformer = build_transformer(args)
matcher = build_matcher(args, type="origin_line")

# The loss receives the matcher and `eos_coef` (presumably the DETR-style
# no-object weight — confirm) alongside the full argument namespace.
criterion_kwargs = dict(eos_coef=args.eos_coef, args=args, matcher=matcher)
criterion = SetCriterion(num_classes, **criterion_kwargs)

In [None]:
# Build the train and val splits; both are configured through `args`
# (e.g. `args.coco_path`).
dataset_train, dataset_val = (
    build_dataset(image_set=split, args=args) for split in ("train", "val")
)

In [None]:
# Draw training samples in a fresh random order every epoch. With a SequentialSampler
# the model would see the training images in the same fixed order each epoch.
# RandomSampler derives its permutation from torch's global RNG, so seed torch
# beforehand for reproducible runs.
# NOTE(review): if sequential order was intentional (e.g. debugging by overfitting a
# fixed sequence), switch back deliberately and document why here.
sampler_train = torch.utils.data.RandomSampler(dataset_train)
# Validation stays unshuffled and deterministic.
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

# Group training indices into fixed-size batches; drop_last=True discards a trailing
# partial batch.
batch_sampler_train = torch.utils.data.BatchSampler(
    sampler_train, args.batch_size, drop_last=True
)


In [None]:
# Sanity check: pull the first batch of training indices before the sampler is handed
# to a DataLoader. With drop_last=True, getting nothing back means the training set holds
# fewer than `args.batch_size` samples and the training loop would have no batches.
first_batch_indices = next(iter(batch_sampler_train), None)
assert first_batch_indices is not None, (
    "training set is smaller than batch_size; drop_last=True yields no batches"
)
assert len(first_batch_indices) == args.batch_size

In [None]:
from torch.utils.data import DataLoader
from helper.misc import collate_fn

In [None]:
# Options shared by both loaders: the project's collate_fn and the worker count.
shared_loader_opts = {"collate_fn": collate_fn, "num_workers": args.num_workers}

# Training: batching (including drop_last) is delegated entirely to batch_sampler_train.
data_loader_train = DataLoader(
    dataset_train, batch_sampler=batch_sampler_train, **shared_loader_opts
)

# Validation: an explicit batch_size + sampler pair; the final partial batch is kept.
data_loader_val = DataLoader(
    dataset_val,
    batch_size=args.batch_size,
    sampler=sampler_val,
    drop_last=False,
    **shared_loader_opts,
)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
# Log to Weights & Biases, keeping run files (and, depending on the Lightning version,
# checkpoints) under args.output_dir instead of the notebook's working directory.
os.makedirs(args.output_dir, exist_ok=True)
wandb_logger = WandbLogger(save_dir=args.output_dir)


In [None]:
# Single-GPU training. Device index 0 is the first GPU visible to this process, i.e. the
# one selected through CUDA_VISIBLE_DEVICES at the top of the notebook.
# default_root_dir moves Lightning's default output location from the CWD to args.output_dir.
MAX_EPOCHS = 5
trainer = Trainer(
    accelerator="gpu",
    devices=[0],
    logger=wandb_logger,
    max_epochs=MAX_EPOCHS,
    default_root_dir=args.output_dir,
)

In [None]:
# Wrap the components in the Letr module that trainer.fit trains below; the optimisation
# settings (lr, lr_drop, batch_size) and model options come from args.
model = Letr(
    backbone,
    transformer,
    criterion,
    num_classes,
    args.num_queries,
    lr_drop=args.lr_drop,
    batch_size=args.batch_size,
    lr=args.lr,
    aux_loss=args.aux_loss,
    layer1_num=args.layer1_num,
)
model  # display the model's repr

In [None]:
# Launch training on data_loader_train, validating with data_loader_val.
trainer.fit(
    model=model,
    train_dataloaders=data_loader_train,
    val_dataloaders=data_loader_val,
)