DDP training question #1096

Open
Henryplay opened this issue Nov 25, 2022 · 2 comments
@Henryplay

Hi, I'm following the DDP tutorial at https://github.com/pytorch/tutorials/blob/master/intermediate_source/ddp_tutorial.rst (the Basic Use Case section) to add DDP training to my own code, using 4 GPUs. After making the changes, the script hangs when I run the demo, while the GPU memory stays allocated. Could you help me?
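
For reference, the Basic Use Case I'm following is roughly this pattern (a simplified sketch from memory, not the tutorial code verbatim; the toy model and tensor shapes are just placeholders):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # every spawned process joins the same group via a shared address/port
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    setup(rank, world_size)
    model = nn.Linear(10, 5).to(rank)           # model is created inside the child process
    ddp_model = DDP(model, device_ids=[rank])   # wrapped after moving to this rank's GPU
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    loss = nn.functional.mse_loss(outputs, torch.randn(20, 5).to(rank))
    loss.backward()
    optimizer.step()

    cleanup()

if __name__ == '__main__':
    world_size = 4
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)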

@Henryplay (Author)

Here is my code:

from math import gamma
import os
import torch
import argparse
from tqdm import tqdm
from utils.scheduler import GradualWarmupScheduler
from modeling.model import CNN
from modeling.loss import CTCLoss
from utils.dataset import CharDict, LoadData, ImageTransform
from utils.utils import paser_config, edit_distance_score, setup_logger
from torch.utils.data import DataLoader

import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:

    def __init__(self, config_file):
        self.configs = paser_config(config_file)
        # os.environ['CUDA_VISIBLE_DEVICES'] = self.configs['trainer']['gpus']
        self.build_dataloader()
        self.build_model()
        self.start_epoch = 0
        self.max_epochs = self.configs['trainer']['epochs']
        self.save_dir = os.path.join(self.configs['trainer']['output_dir'], self.configs['name'])
        if not os.path.exists(self.save_dir) : os.makedirs(self.save_dir)
        log_file_mode = 'a' if self.configs['trainer']["resume_ckpt"] else 'w'
        self.logger = setup_logger(log_file_path=os.path.join(self.save_dir, 'train.log'), log_file_mode=log_file_mode)
        self.checkpoint = {
            'epoch': 0,
            'history_acc': [],
            'history_eds': [],
            'model': {},
            'optimizer': {},
            'lr_scheduler': {},
            'configs': self.configs
        }
        if self.configs['trainer']["finetune_ckpt"]:
            self.model.load_state_dict(torch.load(self.configs['trainer']["finetune_ckpt"])['model'], False)
            #ckpt = torch.load(self.configs['trainer']["finetune_ckpt"])['model']
            #self.model.load_state_dict({k: v for k, v in ckpt.items() if 'fc' not in k},False)
        elif self.configs['trainer']["resume_ckpt"]:
            self.checkpoint = torch.load(self.configs['trainer']["resume_ckpt"])
            self.model.load_state_dict(self.checkpoint['model'])
            self.optimizer.load_state_dict(self.checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(self.checkpoint['lr_scheduler'])
            self.checkpoint['model'].clear()
            self.checkpoint['optimizer'].clear()
            self.checkpoint['lr_scheduler'].clear()
            self.start_epoch = self.checkpoint['epoch'] + 1
        # warp dp-model
        # self.model = torch.nn.DataParallel(self.model)
    def setup(self,rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    def cleanup(self):
        dist.destroy_process_group()
    def train(self,rank,world_size):
        self.setup(rank,world_size)
        self.model = self.model.to(rank)
        self.model = DDP(self.model, device_ids=[rank])
        for epoch in range(self.start_epoch, self.max_epochs):
            self.model.train()
            self.checkpoint['epoch'] = epoch
            for i, datas in enumerate(self.train_dataloader):
                img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
                img = img.to(rank)
                preds = self.model(img)
                loss = self.criterion(preds, targets.to(rank), target_lens.to(rank))
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                # log info
                if i%10 == 0:
                    batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
                    msg = "Epoch: %d/%d, " % (epoch, self.max_epochs) + \
                          "Batch: %d/%d, "%(i, len(self.train_dataloader)) + \
                          "Lr: %.6f, " %  self.scheduler_warmup.get_last_lr()[0] + \
                          "Loss: %.3f, " % loss.item() + \
                          "Acc: %.3f, EDS: %.3f" % (batch_acc, batch_eds)
                    self.logger.info(msg)
            self.scheduler_warmup.step()
            self.cleanup()
            self.eval()

    @torch.no_grad()
    def eval(self):
        self.model.eval()
        nbatch = len(self.test_dataloader)
        acc, eds = 0, 0
        for datas in tqdm(self.test_dataloader, desc="Testing..."):
            img, targets, target_lens = datas["img"], datas["target"], datas["target_len"]
            preds = self.model(img.cuda())
            batch_acc, batch_eds = self.metrics(preds, targets, target_lens)
            acc += batch_acc
            eds += batch_eds
        mean_acc = acc / nbatch
        mean_eds = eds / nbatch

        self.save_model(mean_acc, mean_eds)
        return mean_acc, mean_eds

    def metrics(self, preds, targets, target_lens):
        """WARNING:
            This function will consume a lot of time. Don't use it frequently.
        """
        bs = preds.size(0)
        preds_prob,  preds_idx = preds.permute(0,2,1).detach().softmax(dim=2).max(2)
        decode_idx, decode_prob,_ = self.chardict.ctc_decode(preds_idx.cpu().numpy(), preds_prob.cpu().numpy())
        preds_texts = [self.chardict.idx2text(i, reserve_char='\a') for i in decode_idx]
        target_texts = [self.chardict.idx2text(t[:l], reserve_char='') for t, l in zip(targets, target_lens)]
        ed_score = 0.0
        n_correct = 0
        for s1, s2 in zip(preds_texts, target_texts):
            ed_score += edit_distance_score(s1, s2)
            n_correct += (s1 == s2)
        ed_score /= bs
        batch_acc = n_correct / bs
        return batch_acc, ed_score

    def save_model(self, cur_acc, cur_eds):
        best_acc_path = os.path.join(self.save_dir, "model_best_acc.pth")
        best_eds_path = os.path.join(self.save_dir, "model_best_eds.pth")
        model_last_path = os.path.join(self.save_dir, "model_last.pth")
        self.checkpoint['history_acc'].append(cur_acc)
        self.checkpoint['history_eds'].append(cur_eds)
        self.checkpoint['model'] = self.model.module.state_dict()
        self.checkpoint['optimizer'] = self.optimizer.state_dict()
        self.checkpoint['lr_scheduler'] = self.lr_scheduler.state_dict()

        torch.save(self.checkpoint, model_last_path)
        self.logger.info("Current acc: %.3f, eds: %.3f" % (cur_acc, cur_eds))
        self.logger.info("Save current epoch model to: %s" % model_last_path)
        best_acc = max(self.checkpoint['history_acc'])
        best_eds = max(self.checkpoint['history_eds'])
        if cur_acc >= best_acc:
            torch.save(self.checkpoint, best_acc_path)
            self.logger.info("Best acc: %.3f", cur_acc)
            self.logger.info("Save best Acc model to: %s" % best_acc_path)
        if cur_eds >= best_eds:
            torch.save(self.checkpoint, best_eds_path)
            self.logger.info("Best eds: %.3f", cur_eds)
            self.logger.info("Save best EDS model to: %s" % best_eds_path)

        # release
        self.checkpoint['model'].clear()
        self.checkpoint['optimizer'].clear()
        self.checkpoint['lr_scheduler'].clear()

    def build_model(self):
        in_dim = 1 if self.configs['dataset']['img_mode'] == 'gray' else 3
        out_dim = self.configs['dataset']['ncls']
        self.model = CNN(in_dim, out_dim)
        self.optimizer = getattr(torch.optim, self.configs['optimizer']['type'])(
            self.model.parameters(), **self.configs['optimizer']['args'])
        #set lr_decay
        lr_scheduler_type = self.configs['lr_scheduler']['type']
        if lr_scheduler_type == "StepLR":
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.configs['lr_scheduler']['type'])(
                self.optimizer, **self.configs['lr_scheduler']['args'])
        else:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler,self.configs['lr_scheduler']['type'])(
                self.optimizer,5
            )
        self.criterion = CTCLoss()

    def build_dataloader(self):
        self.chardict = CharDict(
            self.configs['dataset']['dict'], self.configs['dataset']['ncls'])
        imtrans = ImageTransform(
            self.configs['dataset']['img_mode'], self.configs['dataset']['img_size'])
        trainset = LoadData(
            self.configs['dataset']['trainset'], self.chardict, imtrans)
        self.train_dataloader = DataLoader(
            trainset, self.configs['dataset']['batch_size'], shuffle=True, collate_fn=trainset.collate_fn, num_workers=16)
        testset = LoadData(
            self.configs['dataset']['testset'], self.chardict, imtrans)
        self.test_dataloader = DataLoader(
            testset, self.configs['dataset']['batch_size'], shuffle=False, collate_fn=trainset.collate_fn, num_workers=16)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_file', default='config/pycrnn.yaml', type=str)
    args = parser.parse_args()
    trainer = Trainer(args.config_file)
    world_size = 4
    mp.spawn(trainer.train,
            args=(world_size, ),
            nprocs = world_size,
            join=True)
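
One thing I'm not sure about: my DataLoader above uses shuffle=True with no sampler, while the docs pair DDP with a DistributedSampler so that each rank reads a different shard of the dataset. A sketch of what I understand that would look like (a hypothetical helper, not what I currently run; batch_size and num_workers are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_ddp_train_dataloader(trainset, batch_size, rank, world_size):
    # hypothetical rank-aware loader, built inside the spawned process
    sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(
        trainset,
        batch_size=batch_size,
        sampler=sampler,              # shuffle must stay False when a sampler is passed
        collate_fn=trainset.collate_fn,
        num_workers=4,
    )
    return loader, sampler            # keep the sampler so set_epoch(epoch) can be called

As far as I understand, sampler.set_epoch(epoch) should then be called at the start of each epoch so the shuffling differs across epochs.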

@AntyRia commented Aug 24, 2023

Hi, do you also see the application get stuck after starting multiple nodes? On my side, running the official multi-node example hangs as well.
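
In case it helps narrow a hang like this down, NCCL and torch.distributed can print extra diagnostics when these standard environment variables are set before init_process_group runs (a minimal sketch):

import os
# set before dist.init_process_group(); these are the documented verbose levels
os.environ.setdefault("NCCL_DEBUG", "INFO")                 # NCCL prints transport/ring setup info
os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")  # extra torch.distributed checks and logging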
