In [3]:
import numpy as np
import torch
torch.set_printoptions(edgeitems=2, threshold=50, precision=6, linewidth=80, sci_mode=False)

In [4]:
# 读取原始文本内容
with open('../data/p1ch4/jane-austen/1342-0.txt', encoding='utf8') as f:
    text = f.read()

In [5]:
# 打印查看一下文本内容
lines = text.split('\n')
line = lines[200]
line

'“Impossible, Mr. Bennet, impossible, when I am not acquainted with him'

In [6]:
# 这里128是硬编码，由于ASSII码的总长度为128
letter_t = torch.zeros(len(line), 128) # <1>
letter_t.shape

torch.Size([70, 128])

In [7]:
# 文本使用了双引号，不是有效的ASSII码，这里将其筛选出来
for i, letter in enumerate(line.lower().strip()):
    letter_index = ord(letter) if ord(letter) < 128 else 0  # <1>
    letter_t[i][letter_index] = 1

In [8]:
letter_t

tensor([[1., 0.,  ..., 0., 0.],
        [0., 0.,  ..., 0., 0.],
        ...,
        [0., 0.,  ..., 0., 0.],
        [0., 0.,  ..., 0., 0.]])

In [9]:
# 删除标点符合，返回一个单词列表
def clean_words(input_str):
    punctuation = '.,;:"!?”“_-'
    word_list = input_str.lower().replace('\n',' ').split()
    word_list = [word.strip(punctuation) for word in word_list]
    return word_list

words_in_line = clean_words(line)
line, words_in_line

('“Impossible, Mr. Bennet, impossible, when I am not acquainted with him',
 ['impossible',
  'mr',
  'bennet',
  'impossible',
  'when',
  'i',
  'am',
  'not',
  'acquainted',
  'with',
  'him'])

In [11]:
word_list = sorted(set(clean_words(text)))
word_list, len(word_list)

(['',
  '#1342]',
  '$5,000)',
  "'_she",
  "'after",
  "'ah",
  "'as-is'",
  "'bingley",
  "'had",
  "'having",
  "'i",
  "'keep",
  "'lady",
  "'lately",
  "'lydia",
  "'mr",
  "'my",
  "'oh",
  "'s",
  "'this",
  "'tis",
  "'violently",
  "'yes,'",
  "'you",
  '($1',
  '(801)',
  '(a)',
  '(an',
  '(and',
  '(any',
  '(available',
  '(b)',
  '(by',
  '(c)',
  '(comparatively',
  '(does',
  '(for',
  '(glancing',
  '(if',
  '(lady',
  '(like',
  '(most',
  '(my',
  '(or',
  '(trademark/copyright)',
  '(unasked',
  '(what',
  '(who',
  '(www.gutenberg.org)',
  '(“the',
  '*',
  '***',
  '*****',
  '1',
  '1.a',
  '1.b',
  '1.c',
  '1.d',
  '1.e',
  '1.e.1',
  '1.e.2',
  '1.e.3',
  '1.e.4',
  '1.e.5',
  '1.e.6',
  '1.e.7',
  '1.e.8',
  '1.e.9',
  '1.f',
  '1.f.1',
  '1.f.2',
  '1.f.3',
  '1.f.4',
  '1.f.5',
  '1.f.6',
  '10',
  '11',
  '12',
  '13',
  '1342-0.txt',
  '1342-0.zip',
  '14',
  '15',
  '1500',
  '15th',
  '16',
  '17',
  '18',
  '18th',
  '19',
  '1998',
  '2',
  '20',
  '

In [13]:
word2index_dict = {word: i for (i, word) in enumerate(word_list)}

len(word2index_dict), word2index_dict['impossible']

(7261, 3394)

In [8]:
word_t = torch.zeros(len(words_in_line), len(word2index_dict))
for i, word in enumerate(words_in_line):
    word_index = word2index_dict[word]
    word_t[i][word_index] = 1
    print('{:2} {:4} {}'.format(i, word_index, word))
    
print(word_t.shape)


 0 3394 impossible
 1 4305 mr
 2  813 bennet
 3 3394 impossible
 4 7078 when
 5 3315 i
 6  415 am
 7 4436 not
 8  239 acquainted
 9 7148 with
10 3215 him
torch.Size([11, 7261])


In [9]:
word_t = word_t.unsqueeze(1)
word_t.shape

torch.Size([11, 1, 7261])

In [10]:
[(c, ord(c)) for c in sorted(set(text))]


[('\n', 10),
 (' ', 32),
 ('!', 33),
 ('#', 35),
 ('$', 36),
 ('%', 37),
 ("'", 39),
 ('(', 40),
 (')', 41),
 ('*', 42),
 (',', 44),
 ('-', 45),
 ('.', 46),
 ('/', 47),
 ('0', 48),
 ('1', 49),
 ('2', 50),
 ('3', 51),
 ('4', 52),
 ('5', 53),
 ('6', 54),
 ('7', 55),
 ('8', 56),
 ('9', 57),
 (':', 58),
 (';', 59),
 ('?', 63),
 ('@', 64),
 ('A', 65),
 ('B', 66),
 ('C', 67),
 ('D', 68),
 ('E', 69),
 ('F', 70),
 ('G', 71),
 ('H', 72),
 ('I', 73),
 ('J', 74),
 ('K', 75),
 ('L', 76),
 ('M', 77),
 ('N', 78),
 ('O', 79),
 ('P', 80),
 ('Q', 81),
 ('R', 82),
 ('S', 83),
 ('T', 84),
 ('U', 85),
 ('V', 86),
 ('W', 87),
 ('X', 88),
 ('Y', 89),
 ('Z', 90),
 ('[', 91),
 (']', 93),
 ('_', 95),
 ('a', 97),
 ('b', 98),
 ('c', 99),
 ('d', 100),
 ('e', 101),
 ('f', 102),
 ('g', 103),
 ('h', 104),
 ('i', 105),
 ('j', 106),
 ('k', 107),
 ('l', 108),
 ('m', 109),
 ('n', 110),
 ('o', 111),
 ('p', 112),
 ('q', 113),
 ('r', 114),
 ('s', 115),
 ('t', 116),
 ('u', 117),
 ('v', 118),
 ('w', 119),
 ('x', 120),
 ('y',

In [11]:
ord('l')

108

#### torch.squeeze(A，N)

torch.unsqueeze()函数的作用减少数组A指定位置N的维度，如果不指定位置参数N，那么将删除所有维度为1的维度。
如果数组A的维度为（1，1，3）那么执行 torch.squeeze(A，1) 后A的维度变为 （1，3），中间的维度被删除

注：
1. 如果指定的维度大于1，那么将操作无效
2. 如果不指定维度N，那么将删除所有维度为1的维度

In [33]:
a = torch.randn(1, 2, 3)
print(f'a.shape                 {a.shape}')
print(f'torch.squeeze(a)        {torch.squeeze(a).shape}')
print(f'torch.squeeze(a, 0)     {torch.squeeze(a, 0).shape}')
print(f'torch.squeeze(a, 1)     {torch.squeeze(a, 1).shape}')
print(f'torch.squeeze(a, 2)     {torch.squeeze(a, 2).shape}')


a.shape                 torch.Size([1, 2, 3])
torch.squeeze(a)        torch.Size([2, 3])
torch.squeeze(a, 0)     torch.Size([2, 3])
torch.squeeze(a, 1)     torch.Size([1, 2, 3])
torch.squeeze(a, 2)     torch.Size([1, 2, 3])


#### torch.unsqueeze(A，N)

torch.unsqueeze()函数的作用增加数组A指定位置N的维度，例如两行三列的数组A维度为(2，3)，那么这个数组就有三个位置可以增加维度，
分别是（ [位置0] 2，[位置1] 3 [位置2] ）或者是 （ [位置-3] 2，[位置-2] 3 [位置-1] ），
如果执行 torch.unsqueeze(A，1)，数据的维度就变为了 （2，1，3）

In [38]:
a = torch.randn(2, 3)
print(f'a.shape                   {a.shape}')
print(f'torch.unsqueeze(a, 0)     {torch.unsqueeze(a, 0).shape}')
print(f'torch.unsqueeze(a, 1)     {torch.unsqueeze(a, 1).shape}')
print(f'torch.unsqueeze(a, 2)     {torch.unsqueeze(a, 2).shape}')

a.shape                   torch.Size([2, 3])
torch.unsqueeze(a, 0)     torch.Size([1, 2, 3])
torch.unsqueeze(a, 1)     torch.Size([2, 1, 3])
torch.unsqueeze(a, 2)     torch.Size([2, 3, 1])
