diff --git a/src/main/scala/intellij/haskell/editor/HaskellLiveTemplate.scala b/src/main/scala/intellij/haskell/editor/HaskellLiveTemplate.scala index 39c6ec664..5d31f13ff 100644 --- a/src/main/scala/intellij/haskell/editor/HaskellLiveTemplate.scala +++ b/src/main/scala/intellij/haskell/editor/HaskellLiveTemplate.scala @@ -5,7 +5,7 @@ import com.intellij.codeInsight.template.impl.DefaultLiveTemplatesProvider import com.intellij.psi.PsiFile import com.intellij.psi.util.PsiTreeUtil import intellij.haskell.HaskellFileType -import intellij.haskell.psi.{HaskellExpression, HaskellFileHeader, HaskellModuleBody} +import intellij.haskell.psi.{HaskellExpression, HaskellFileHeader} class HaskellTemplateContextType extends TemplateContextType("HASKELL_FILE", "Haskell") { override def isInContext(file: PsiFile, offset: Int): Boolean = @@ -15,10 +15,10 @@ class HaskellTemplateContextType extends TemplateContextType("HASKELL_FILE", "Ha class HaskellPragmaTemplateContextType extends TemplateContextType("HASKELL_PRAGMA", "Pragma", classOf[HaskellTemplateContextType]) { override def isInContext(file: PsiFile, offset: Int): Boolean = { if (file.getFileType != HaskellFileType.Instance) return false - var element = file.findElementAt(offset) - if (element == null) element = file.findElementAt(offset - 1) - if (element == null) return false - PsiTreeUtil.getParentOfType(element, classOf[HaskellFileHeader]) != null + if (offset < 5) return true + val element = file.findElementAt(offset - 5) + element != null && + PsiTreeUtil.getParentOfType(element, classOf[HaskellFileHeader]) != null } } @@ -27,8 +27,7 @@ class HaskellGlobalDefinitionTemplateContextType extends TemplateContextType("HA if (file.getFileType != HaskellFileType.Instance) return false var element = file.findElementAt(offset) if (element == null) element = file.findElementAt(offset - 1) - if (element == null) return false - PsiTreeUtil.getParentOfType(element, classOf[HaskellModuleBody]) != null && + element != null && PsiTreeUtil.getParentOfType(element, classOf[HaskellExpression]) == null } }