diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index 04cfb16f627..3c1d616e081 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -179,7 +179,8 @@ def __init__(self, data_dir): self.data_tensors.append(lineToTensor(name)) self.labels.append(label) - #Cache the tensor representation of the labels + # Create numerical representation of labels + # Store unique labels and convert each label to its index in the label vocabulary self.labels_uniq = list(labels_set) for idx in range(len(self.labels)): temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long)