diff --git a/arg_parser.py b/arg_parser.py
index 0bdea97..d8d8fde 100644
--- a/arg_parser.py
+++ b/arg_parser.py
@@ -19,7 +19,7 @@ def __init__(self):
         self.parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducible outputs.')
 
         # System configurations
-        self.parser.add_argument('--gpu_ids', type=str, default='0' if torch.cuda.is_available() else '-1', help='Comma-separated list of GPU IDs. Use -1 for CPU.')
+        self.parser.add_argument('--gpu_ids', type=str, default='0,1' if torch.cuda.is_available() else '-1', help='Comma-separated list of GPU IDs. Use -1 for CPU.')
         self.parser.add_argument('--num_workers', default=4, type=int, help='Number of threads for the DataLoader.')
         self.parser.add_argument('--dataset_root', type=str, default=join(dirname(argv[0]), 'data'), help='The root of the dataset directory')
 
diff --git a/model.py b/model.py
index ef66da0..0254f3e 100644
--- a/model.py
+++ b/model.py
@@ -26,12 +26,6 @@ def get_model(model_load_path):
     }
 
     model = ResNet1d(**model_args)
-
-    # TODO because the original model authors did not register parameters
-    # correctly with nn.ModuleList, it is impossible to simultaneously load
-    # the previous model checkpoints and run on multiple GPUs. This should
-    # be remedied a la
-    # https://discuss.pytorch.org/t/discrepancy-between-manual-parameter-registration-vs-using-nn-modulelist-when-parallelizing/181055
     model.load_state_dict(torch.load(model_load_path, map_location='cpu')['model'])
 
     # The model originally took 12 channels as input (corresponding to a 12
@@ -175,14 +169,14 @@ def __init__(self, input_dim, blocks_dim, n_classes, kernel_size=17, dropout_rat
         self.bn1 = nn.BatchNorm1d(n_filters_out)
 
         # Residual block layers
-        self.res_blocks = []
         for i, (n_filters, n_samples) in enumerate(blocks_dim):
             n_filters_in, n_filters_out = n_filters_out, n_filters
             n_samples_in, n_samples_out = n_samples_out, n_samples
             downsample = _downsample(n_samples_in, n_samples_out)
-            resblk1d = ResBlock1d(n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate)
-            self.add_module('resblock1d_{0}'.format(i), resblk1d)
-            self.res_blocks += [resblk1d]
+            setattr(self, f'resblock1d_{i}', ResBlock1d(
+                n_filters_in, n_filters_out, downsample,
+                kernel_size, dropout_rate
+            ))
 
         # Linear layer
         n_filters_last, n_samples_last = blocks_dim[-1]
@@ -198,8 +192,8 @@ def forward(self, x):
 
         # Residual blocks
         y = x
-        for blk in self.res_blocks:
-            x, y = blk(x, y)
+        for i in range(self.n_blk):
+            x, y = getattr(self, f'resblock1d_{i}')(x, y)
 
         # Flatten array
         x = x.view(x.size(0), -1)
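
Note on the forward-pass change (not part of the diff): looking the residual blocks up on self at call time, rather than iterating a plain Python list built in __init__, is what makes nn.DataParallel safe here. DataParallel copies the module onto each GPU, and a plain list attribute would still reference the original blocks on the original device, while getattr resolves to the replica's own submodules. Because the blocks are still registered under the same names, the state_dict keys (resblock1d_0.*, resblock1d_1.*, ...) are unchanged and existing checkpoints continue to load. Below is a minimal, hypothetical sketch of the same pattern; TinyStack and block_{i} are illustrative names, not from this repo.

    import torch
    import torch.nn as nn


    class TinyStack(nn.Module):
        """Illustrative only: mirrors the registration/lookup pattern used in ResNet1d."""

        def __init__(self, n_blocks=3, dim=8):
            super().__init__()
            self.n_blk = n_blocks
            for i in range(n_blocks):
                # Assigning an nn.Module to an attribute registers it as a submodule,
                # so its parameters appear in state_dict() under 'block_{i}.*' and
                # are replicated by nn.DataParallel.
                setattr(self, f'block_{i}', nn.Linear(dim, dim))

        def forward(self, x):
            # Look the blocks up on self at call time: inside a DataParallel
            # replica, self is the copy living on the current GPU, so the
            # replica's own (correct-device) parameters are used.
            for i in range(self.n_blk):
                x = getattr(self, f'block_{i}')(x)
            return x


    model = TinyStack()
    print(list(model.state_dict().keys()))  # ['block_0.weight', 'block_0.bias', ...]
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model.cuda())
        _ = model(torch.randn(4, 8, device='cuda'))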